From 434b0024e51216def82d992d55771430edbfa027 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 10:14:51 +0100 Subject: [PATCH 01/10] preparing for black run tool-setup in pyproject.toml, and fixed Periodic table to not be formatted. Signed-off-by: Nick Papior --- pyproject.toml | 8 +++++--- src/sisl/atom.py | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 63ff2ad930..47d2c26c42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -271,16 +271,18 @@ commands = pytest -s -rXs {posargs} [tool.isort] # how should sorting be done -profile = "hug" +profile = "black" sections = ['FUTURE', 'STDLIB', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER'] known_first_party = ["sisl_toolbox", "sisl"] -line_length = 90 +line_length = 88 overwrite_in_place = true extend_skip = ["src/sisl/__init__.py"] +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311", "py312"] -# Options for cibuildwheel [tool.cibuildwheel] build-verbosity = 3 test-extras = "test" diff --git a/src/sisl/atom.py b/src/sisl/atom.py index 8f05d9b4e4..d00423a85f 100644 --- a/src/sisl/atom.py +++ b/src/sisl/atom.py @@ -70,8 +70,9 @@ class PeriodicTable: True >>> 1.7 == PeriodicTable().radii(6,'vdw') True - """ + + # fmt: off _Z_int = { 'Actinium': 89, 'Ac': 89, '89': 89, 89: 89, 'Aluminum': 13, 'Al': 13, '13': 13, 13: 13, @@ -796,6 +797,7 @@ class PeriodicTable: 117: -1, 118: -1, } + # fmt: on def Z(self, key): """ Atomic number based on general input From 0ff17210003a85b08f683307e9a73555142c8707 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 10:22:46 +0100 Subject: [PATCH 02/10] ran black on all sources Signed-off-by: Nick Papior --- benchmarks/bloch.py | 4 +- benchmarks/bz_parallel.py | 7 +- benchmarks/bz_replace.py | 29 +- benchmarks/graphene.py | 6 +- benchmarks/graphene_io.py | 4 +- benchmarks/graphene_repeat.py | 6 +- benchmarks/graphene_tile.py | 6 +- benchmarks/graphene_within.py | 16 +- benchmarks/optimizations/array_arange.ipynb | 150 +- benchmarks/sparse_matrices.py | 4 +- benchmarks/sparse_matrices_ufunc.py | 4 +- benchmarks/stats.py | 7 +- developments/distributions.py | 88 +- developments/miller.py | 6 +- docs/conf.py | 174 +- docs/nodes/nodes_intro.ipynb | 39 +- docs/tutorials/tutorial_05_graphene.py | 10 +- docs/tutorials/tutorial_05_square.py | 20 +- docs/tutorials/tutorial_06_square.py | 19 +- docs/tutorials/tutorial_es_1.ipynb | 73 +- docs/tutorials/tutorial_es_2.ipynb | 87 +- docs/tutorials/tutorial_siesta_1.ipynb | 73 +- docs/tutorials/tutorial_siesta_2.ipynb | 63 +- .../viz_module/basic-tutorials/Demo.ipynb | 9 +- .../Intro to combining plots.ipynb | 40 +- .../viz_module/diy/Adding new backends.ipynb | 21 +- .../viz_module/diy/Building a new plot.ipynb | 2 - .../viz_module/showcase/BandsPlot.ipynb | 57 +- .../viz_module/showcase/FatbandsPlot.ipynb | 19 +- .../viz_module/showcase/GeometryPlot.ipynb | 91 +- .../viz_module/showcase/GridPlot.ipynb | 21 +- .../viz_module/showcase/PdosPlot.ipynb | 69 +- .../viz_module/showcase/SitesPlot.ipynb | 16 +- .../showcase/WavefunctionPlot.ipynb | 16 +- .../viz_module/showcase/_template/create.py | 4 +- .../viz_module/tests/test_tutorials.py | 41 +- examples/ex_01.py | 6 +- examples/ex_02.py | 6 +- examples/ex_03.py | 30 +- src/sisl/__init__.py | 18 +- src/sisl/_array.py | 34 +- src/sisl/_category.py | 82 +- src/sisl/_common.py | 3 +- src/sisl/_dispatch_class.py | 20 +- src/sisl/_dispatcher.py | 265 ++- src/sisl/_environ.py | 96 +- src/sisl/_help.py | 59 +- src/sisl/_internal.py | 2 + src/sisl/_namedindex.py | 64 +- src/sisl/_plot.py | 14 +- src/sisl/_typing.py | 2 +- src/sisl/_typing_ext/numpy.py | 2 +- src/sisl/atom.py | 333 ++-- src/sisl/conftest.py | 189 +- src/sisl/constant.py | 20 +- src/sisl/geom/_common.py | 6 +- src/sisl/geom/_composite.py | 12 +- src/sisl/geom/basic.py | 104 +- src/sisl/geom/bilayer.py | 62 +- src/sisl/geom/category/_coord.py | 98 +- src/sisl/geom/category/_kind.py | 26 +- src/sisl/geom/category/_neighbours.py | 19 +- src/sisl/geom/category/base.py | 4 +- .../geom/category/tests/test_geom_category.py | 16 +- src/sisl/geom/flat.py | 103 +- src/sisl/geom/nanoribbon.py | 220 ++- src/sisl/geom/nanotube.py | 20 +- src/sisl/geom/special.py | 20 +- src/sisl/geom/surfaces.py | 304 +-- src/sisl/geom/tests/test_geom.py | 227 ++- src/sisl/geometry.py | 1657 ++++++++++------- src/sisl/grid.py | 566 +++--- src/sisl/io/_help.py | 15 +- src/sisl/io/_multiple.py | 83 +- src/sisl/io/bigdft/__init__.py | 2 +- src/sisl/io/bigdft/ascii.py | 48 +- src/sisl/io/cube.py | 93 +- src/sisl/io/fhiaims/__init__.py | 2 +- src/sisl/io/fhiaims/_geometry.py | 30 +- src/sisl/io/gulp/__init__.py | 2 +- src/sisl/io/gulp/fc.py | 19 +- src/sisl/io/gulp/got.py | 140 +- src/sisl/io/gulp/tests/test_gout.py | 12 +- src/sisl/io/ham.py | 129 +- src/sisl/io/molden.py | 26 +- src/sisl/io/openmx/__init__.py | 2 +- src/sisl/io/openmx/md.py | 2 +- src/sisl/io/openmx/omx.py | 151 +- src/sisl/io/openmx/sile.py | 2 +- src/sisl/io/orca/__init__.py | 2 +- src/sisl/io/orca/sile.py | 2 +- src/sisl/io/orca/stdout.py | 135 +- src/sisl/io/orca/tests/test_stdout.py | 222 ++- src/sisl/io/orca/tests/test_txt.py | 23 +- src/sisl/io/orca/txt.py | 58 +- src/sisl/io/pdb.py | 184 +- src/sisl/io/scaleup/__init__.py | 2 +- src/sisl/io/scaleup/orbocc.py | 13 +- src/sisl/io/scaleup/ref.py | 70 +- src/sisl/io/scaleup/rham.py | 24 +- src/sisl/io/siesta/__init__.py | 2 +- src/sisl/io/siesta/_help.py | 47 +- src/sisl/io/siesta/ani.py | 2 +- src/sisl/io/siesta/bands.py | 60 +- src/sisl/io/siesta/basis.py | 99 +- src/sisl/io/siesta/binaries.py | 791 +++++--- src/sisl/io/siesta/eig.py | 126 +- src/sisl/io/siesta/fa.py | 22 +- src/sisl/io/siesta/fc.py | 20 +- src/sisl/io/siesta/fdf.py | 648 ++++--- src/sisl/io/siesta/kp.py | 48 +- src/sisl/io/siesta/orb_indx.py | 14 +- src/sisl/io/siesta/pdos.py | 251 ++- src/sisl/io/siesta/siesta_grid.py | 105 +- src/sisl/io/siesta/siesta_nc.py | 476 +++-- src/sisl/io/siesta/sile.py | 22 +- src/sisl/io/siesta/stdout.py | 436 +++-- src/sisl/io/siesta/struct.py | 30 +- src/sisl/io/siesta/tests/test_ani.py | 11 +- src/sisl/io/siesta/tests/test_bands.py | 22 +- src/sisl/io/siesta/tests/test_basis.py | 14 +- src/sisl/io/siesta/tests/test_dm.py | 46 +- src/sisl/io/siesta/tests/test_eig.py | 4 +- src/sisl/io/siesta/tests/test_fa.py | 8 +- src/sisl/io/siesta/tests/test_fc.py | 22 +- src/sisl/io/siesta/tests/test_fdf.py | 397 ++-- src/sisl/io/siesta/tests/test_gf.py | 28 +- src/sisl/io/siesta/tests/test_grid.py | 16 +- src/sisl/io/siesta/tests/test_kp.py | 8 +- src/sisl/io/siesta/tests/test_orb_indx.py | 6 +- src/sisl/io/siesta/tests/test_pdos.py | 14 +- src/sisl/io/siesta/tests/test_siesta.py | 158 +- src/sisl/io/siesta/tests/test_stdout.py | 24 +- .../io/siesta/tests/test_stdout_charges.py | 15 +- src/sisl/io/siesta/tests/test_struct.py | 24 +- src/sisl/io/siesta/tests/test_tsde.py | 28 +- src/sisl/io/siesta/tests/test_tshs.py | 101 +- src/sisl/io/siesta/tests/test_wfsx.py | 8 +- src/sisl/io/siesta/tests/test_xv.py | 20 +- src/sisl/io/siesta/transiesta_grid.py | 56 +- src/sisl/io/siesta/xv.py | 38 +- src/sisl/io/sile.py | 395 ++-- src/sisl/io/table.py | 83 +- src/sisl/io/tbtrans/__init__.py | 2 +- src/sisl/io/tbtrans/_cdf.py | 228 ++- src/sisl/io/tbtrans/binaries.py | 12 +- src/sisl/io/tbtrans/delta.py | 314 ++-- src/sisl/io/tbtrans/pht.py | 52 +- src/sisl/io/tbtrans/phtproj.py | 11 +- src/sisl/io/tbtrans/se.py | 105 +- src/sisl/io/tbtrans/sile.py | 3 +- src/sisl/io/tbtrans/tbt.py | 1253 ++++++++----- src/sisl/io/tbtrans/tbtproj.py | 369 ++-- src/sisl/io/tbtrans/tests/test_delta.py | 107 +- src/sisl/io/tbtrans/tests/test_tbt.py | 339 ++-- src/sisl/io/tbtrans/tests/test_tbtproj.py | 183 +- src/sisl/io/tests/test_cube.py | 52 +- src/sisl/io/tests/test_object.py | 120 +- src/sisl/io/tests/test_table.py | 60 +- src/sisl/io/tests/test_tb.py | 10 +- src/sisl/io/tests/test_xsf.py | 45 +- src/sisl/io/tests/test_xyz.py | 59 +- src/sisl/io/vasp/__init__.py | 2 +- src/sisl/io/vasp/car.py | 80 +- src/sisl/io/vasp/chg.py | 8 +- src/sisl/io/vasp/doscar.py | 18 +- src/sisl/io/vasp/eigenval.py | 13 +- src/sisl/io/vasp/locpot.py | 12 +- src/sisl/io/vasp/sile.py | 4 +- src/sisl/io/vasp/stdout.py | 74 +- src/sisl/io/vasp/tests/test_car.py | 44 +- src/sisl/io/vasp/tests/test_chg.py | 8 +- src/sisl/io/vasp/tests/test_doscar.py | 4 +- src/sisl/io/vasp/tests/test_eigenval.py | 4 +- src/sisl/io/vasp/tests/test_locpot.py | 6 +- src/sisl/io/vasp/tests/test_stdout.py | 20 +- src/sisl/io/wannier90/__init__.py | 2 +- src/sisl/io/wannier90/seedname.py | 135 +- src/sisl/io/wannier90/tests/test_seedname.py | 48 +- src/sisl/io/xsf.py | 215 ++- src/sisl/io/xyz.py | 48 +- src/sisl/lattice.py | 416 +++-- src/sisl/linalg/base.py | 101 +- src/sisl/linalg/special.py | 8 +- src/sisl/linalg/tests/test_solve.py | 4 +- src/sisl/messages.py | 77 +- src/sisl/mixing/base.py | 89 +- src/sisl/mixing/diis.py | 70 +- src/sisl/mixing/linear.py | 14 +- src/sisl/mixing/tests/test_linear.py | 5 +- src/sisl/nodes/context.py | 17 +- src/sisl/nodes/dispatcher.py | 18 +- src/sisl/nodes/node.py | 215 ++- src/sisl/nodes/syntax_nodes.py | 11 +- src/sisl/nodes/tests/test_context.py | 27 +- src/sisl/nodes/tests/test_node.py | 61 +- src/sisl/nodes/tests/test_syntax_nodes.py | 10 +- src/sisl/nodes/tests/test_utils.py | 48 +- src/sisl/nodes/tests/test_workflow.py | 68 +- src/sisl/nodes/utils.py | 60 +- src/sisl/nodes/workflow.py | 552 ++++-- src/sisl/oplist.py | 27 +- src/sisl/orbital.py | 437 +++-- src/sisl/physics/__init__.py | 1 + src/sisl/physics/_brillouinzone_apply.py | 215 ++- src/sisl/physics/_feature.py | 6 +- src/sisl/physics/bloch.py | 26 +- src/sisl/physics/brillouinzone.py | 406 ++-- src/sisl/physics/densitymatrix.py | 230 ++- src/sisl/physics/distribution.py | 68 +- src/sisl/physics/dynamicalmatrix.py | 49 +- src/sisl/physics/electron.py | 368 ++-- src/sisl/physics/energydensitymatrix.py | 36 +- src/sisl/physics/hamiltonian.py | 76 +- src/sisl/physics/overlap.py | 18 +- src/sisl/physics/phonon.py | 76 +- src/sisl/physics/self_energy.py | 523 ++++-- src/sisl/physics/sparse.py | 219 ++- src/sisl/physics/spin.py | 92 +- src/sisl/physics/state.py | 336 ++-- src/sisl/physics/tests/test_bloch.py | 14 +- src/sisl/physics/tests/test_brillouinzone.py | 287 +-- src/sisl/physics/tests/test_density_matrix.py | 485 +++-- src/sisl/physics/tests/test_distribution.py | 24 +- .../physics/tests/test_dynamical_matrix.py | 75 +- .../tests/test_energy_density_matrix.py | 101 +- src/sisl/physics/tests/test_feature.py | 8 +- src/sisl/physics/tests/test_hamiltonian.py | 842 +++++---- src/sisl/physics/tests/test_overlap.py | 31 +- src/sisl/physics/tests/test_physics_sparse.py | 226 ++- src/sisl/physics/tests/test_self_energy.py | 128 +- src/sisl/physics/tests/test_spin.py | 62 +- src/sisl/physics/tests/test_state.py | 39 +- src/sisl/quaternion.py | 91 +- src/sisl/selector.py | 56 +- src/sisl/shape/__init__.py | 2 +- src/sisl/shape/_cylinder.py | 47 +- src/sisl/shape/base.py | 207 +- src/sisl/shape/ellipsoid.py | 105 +- src/sisl/shape/prism4.py | 85 +- src/sisl/shape/tests/test_cylinder.py | 18 +- src/sisl/shape/tests/test_ellipsoid.py | 104 +- src/sisl/shape/tests/test_prism4.py | 100 +- src/sisl/shape/tests/test_shape.py | 96 +- src/sisl/sparse.py | 469 +++-- src/sisl/sparse_geometry.py | 701 ++++--- src/sisl/tests/test_atom.py | 88 +- src/sisl/tests/test_atoms.py | 216 +-- src/sisl/tests/test_geometry.py | 513 ++--- src/sisl/tests/test_geometry_return.py | 36 +- src/sisl/tests/test_grid.py | 79 +- src/sisl/tests/test_help.py | 15 +- src/sisl/tests/test_lattice.py | 110 +- src/sisl/tests/test_messages.py | 24 +- src/sisl/tests/test_namedindex.py | 37 +- src/sisl/tests/test_oplist.py | 11 +- src/sisl/tests/test_orbital.py | 139 +- src/sisl/tests/test_plot.py | 35 +- src/sisl/tests/test_quaternion.py | 26 +- src/sisl/tests/test_selector.py | 12 +- src/sisl/tests/test_sgeom.py | 112 +- src/sisl/tests/test_sgrid.py | 134 +- src/sisl/tests/test_sparse.py | 251 +-- src/sisl/tests/test_sparse_geometry.py | 163 +- src/sisl/tests/test_sparse_orbital.py | 128 +- src/sisl/typing/_common.py | 8 +- src/sisl/typing/tests/test_typing.py | 3 +- src/sisl/unit/base.py | 149 +- src/sisl/unit/siesta.py | 107 +- src/sisl/unit/tests/test_unit.py | 60 +- src/sisl/unit/tests/test_unit_siesta.py | 32 +- src/sisl/utils/_arrays.py | 32 +- src/sisl/utils/_sisl_cmd.py | 36 +- src/sisl/utils/cmd.py | 79 +- src/sisl/utils/mathematics.py | 32 +- src/sisl/utils/misc.py | 54 +- src/sisl/utils/ranges.py | 26 +- src/sisl/utils/tests/test_cmd.py | 58 +- src/sisl/utils/tests/test_misc.py | 85 +- src/sisl/utils/tests/test_ranges.py | 148 +- src/sisl/viz/__init__.py | 9 +- src/sisl/viz/_plotables.py | 230 ++- src/sisl/viz/_plotables_register.py | 28 +- src/sisl/viz/_presets.py | 4 +- src/sisl/viz/_single_dispatch.py | 7 +- src/sisl/viz/_splot.py | 161 +- src/sisl/viz/_xarray_accessor.py | 17 +- src/sisl/viz/data/bands.py | 312 ++-- src/sisl/viz/data/data.py | 12 +- src/sisl/viz/data/eigenstate.py | 38 +- src/sisl/viz/data/pdos.py | 159 +- src/sisl/viz/data/sisl_objs.py | 15 +- src/sisl/viz/data/tests/conftest.py | 5 +- src/sisl/viz/data/tests/test_bands.py | 47 +- src/sisl/viz/data/tests/test_pdos.py | 41 +- src/sisl/viz/data/xarray.py | 5 +- src/sisl/viz/data_sources/atom_data.py | 39 +- src/sisl/viz/data_sources/bond_data.py | 32 +- src/sisl/viz/data_sources/data_source.py | 9 +- src/sisl/viz/data_sources/eigenstate_data.py | 11 +- src/sisl/viz/data_sources/file/__init__.py | 2 +- src/sisl/viz/data_sources/file/file_source.py | 15 +- src/sisl/viz/data_sources/file/siesta.py | 19 +- .../viz/data_sources/hamiltonian_source.py | 7 +- src/sisl/viz/data_sources/orbital_data.py | 23 +- src/sisl/viz/figure/__init__.py | 14 +- src/sisl/viz/figure/blender.py | 288 ++- src/sisl/viz/figure/figure.py | 464 +++-- src/sisl/viz/figure/matplotlib.py | 250 ++- src/sisl/viz/figure/plotly.py | 725 +++++--- src/sisl/viz/figure/py3dmol.py | 164 +- src/sisl/viz/plot.py | 12 +- src/sisl/viz/plots/bands.py | 174 +- src/sisl/viz/plots/geometry.py | 230 ++- src/sisl/viz/plots/grid.py | 185 +- src/sisl/viz/plots/merged.py | 11 +- src/sisl/viz/plots/orbital_groups_plot.py | 62 +- src/sisl/viz/plots/pdos.py | 36 +- src/sisl/viz/plots/tests/test_bands.py | 8 +- src/sisl/viz/plots/tests/test_geometry.py | 4 +- src/sisl/viz/plots/tests/test_grid.py | 4 +- src/sisl/viz/plots/tests/test_pdos.py | 7 +- src/sisl/viz/plotters/__init__.py | 2 +- src/sisl/viz/plotters/cell.py | 31 +- src/sisl/viz/plotters/grid.py | 47 +- src/sisl/viz/plotters/plot_actions.py | 26 +- src/sisl/viz/plotters/tests/test_xarray.py | 7 +- src/sisl/viz/plotters/xarray.py | 159 +- src/sisl/viz/plotutils.py | 156 +- src/sisl/viz/processors/__init__.py | 2 +- src/sisl/viz/processors/atom.py | 52 +- src/sisl/viz/processors/axes.py | 31 +- src/sisl/viz/processors/bands.py | 182 +- src/sisl/viz/processors/cell.py | 119 +- src/sisl/viz/processors/coords.py | 90 +- src/sisl/viz/processors/data.py | 20 +- src/sisl/viz/processors/eigenstate.py | 72 +- src/sisl/viz/processors/geometry.py | 203 +- src/sisl/viz/processors/grid.py | 270 ++- src/sisl/viz/processors/logic.py | 11 +- src/sisl/viz/processors/math.py | 2 +- src/sisl/viz/processors/orbital.py | 363 ++-- src/sisl/viz/processors/spin.py | 25 +- src/sisl/viz/processors/tests/test_axes.py | 21 +- src/sisl/viz/processors/tests/test_bands.py | 100 +- src/sisl/viz/processors/tests/test_cell.py | 36 +- src/sisl/viz/processors/tests/test_coords.py | 53 +- src/sisl/viz/processors/tests/test_data.py | 6 +- .../viz/processors/tests/test_eigenstate.py | 29 +- .../viz/processors/tests/test_geometry.py | 71 +- src/sisl/viz/processors/tests/test_grid.py | 155 +- .../viz/processors/tests/test_groupreduce.py | 249 ++- src/sisl/viz/processors/tests/test_logic.py | 6 +- src/sisl/viz/processors/tests/test_math.py | 2 - src/sisl/viz/processors/tests/test_orbital.py | 74 +- .../processors/tests/test_sci_groupreduce.py | 66 +- src/sisl/viz/processors/tests/test_spin.py | 2 - src/sisl/viz/processors/wavefunction.py | 59 +- src/sisl/viz/processors/xarray.py | 69 +- src/sisl/viz/types.py | 38 +- src/sisl_toolbox/btd/_btd.py | 550 +++--- src/sisl_toolbox/cli/__init__.py | 9 +- src/sisl_toolbox/models/_base.py | 28 +- src/sisl_toolbox/models/_graphene/__init__.py | 2 +- src/sisl_toolbox/models/_graphene/_base.py | 7 +- .../models/_graphene/_hamiltonian.py | 83 +- src/sisl_toolbox/siesta/atom/_atom.py | 243 ++- .../siesta/minimizer/_atom_basis.py | 154 +- .../siesta/minimizer/_atom_pseudo.py | 56 +- src/sisl_toolbox/siesta/minimizer/_metric.py | 5 +- .../siesta/minimizer/_metric_siesta.py | 97 +- .../siesta/minimizer/_minimize.py | 122 +- .../siesta/minimizer/_minimize_siesta.py | 44 +- src/sisl_toolbox/siesta/minimizer/_runner.py | 118 +- .../siesta/minimizer/_variable.py | 34 +- .../siesta/minimizer/_yaml_reader.py | 13 +- .../transiesta/poisson/fftpoisson_fix.py | 336 +++- tools/changelog.py | 14 +- tools/codata.py | 66 +- 379 files changed, 22872 insertions(+), 14738 deletions(-) diff --git a/benchmarks/bloch.py b/benchmarks/bloch.py index 7be5e9a167..4ade52ea29 100644 --- a/benchmarks/bloch.py +++ b/benchmarks/bloch.py @@ -51,6 +51,6 @@ stat = pstats.Stats(pr) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/benchmarks/bz_parallel.py b/benchmarks/bz_parallel.py index 132f094401..c2fbc80b73 100644 --- a/benchmarks/bz_parallel.py +++ b/benchmarks/bz_parallel.py @@ -52,8 +52,11 @@ par = bz.apply.renew(eta=True) E = np.linspace(-2, 2, 200) + + def wrap_DOS(es): return es.DOS(E) -dos = par.ndarray.eigenstate(wrap=wrap_DOS) -#dos = par.average.eigenstate(wrap=wrap_DOS) + +dos = par.ndarray.eigenstate(wrap=wrap_DOS) +# dos = par.average.eigenstate(wrap=wrap_DOS) diff --git a/benchmarks/bz_replace.py b/benchmarks/bz_replace.py index ef848d39ca..a127ca3f6d 100644 --- a/benchmarks/bz_replace.py +++ b/benchmarks/bz_replace.py @@ -39,12 +39,16 @@ nks = [50 for _ in range(nlvls)] if ndim == 1: + def get_nk(nk): return [nk, 1, 1] + ns = [2 for _ in range(nlvls)] else: + def get_nk(nk): return [nk, nk, 1] + if nlvls > 2: ns = [50 for _ in range(nlvls)] else: @@ -55,11 +59,12 @@ def get_nk(nk): def yield_kpoint(bz, n): yield from np.unique(np.random.randint(len(bz), size=n))[::-1] + # Replacement function def add_levels(bz, nks, ns, fast=False, as_index=False, debug=False): - """ Add different levels according to the length of `ns` """ + """Add different levels according to the length of `ns`""" global nlvls lvl = nlvls - len(nks) @@ -72,6 +77,7 @@ def add_levels(bz, nks, ns, fast=False, as_index=False, debug=False): bz = bz.copy() from io import StringIO + s = StringIO() def print_s(force=True): @@ -84,12 +90,12 @@ def print_s(force=True): # reset s s = StringIO() + if debug: print(f"lvl = {lvl}", file=s) print_s() if len(nks) > 0: - # calculate the size of the current BZ dsize = bz._size / bz._diag @@ -115,7 +121,9 @@ def print_s(force=True): # create the single monkhorst pack we will use for replacements if fast: new_bz = sisl.MonkhorstPack(bz.parent, nk, size=dsize, trs=False) - new, reps = add_levels(new_bz, nks[:-1], ns[:-1], fast, as_index, debug=debug) + new, reps = add_levels( + new_bz, nks[:-1], ns[:-1], fast, as_index, debug=debug + ) if as_index: bz.replace(iks, new, displacement=True, as_index=True) @@ -137,8 +145,12 @@ def print_s(force=True): # Recursively add a new level # create the single monkhorst pack we will use for replacements - new_bz = sisl.MonkhorstPack(bz.parent, nk, size=dsize, trs=False, displacement=k) - new, reps = add_levels(new_bz, nks[:-1], ns[:-1], fast, as_index, debug=debug) + new_bz = sisl.MonkhorstPack( + bz.parent, nk, size=dsize, trs=False, displacement=k + ) + new, reps = add_levels( + new_bz, nks[:-1], ns[:-1], fast, as_index, debug=debug + ) # calculate number of replaced k-points if debug: @@ -152,6 +164,7 @@ def print_s(force=True): if False: import matplotlib.pyplot as plt + plt.figure() plt.scatter(bz.k[:, 0], bz.k[:, 1]) plt.title(f"{lvl} and {ik}") @@ -166,7 +179,7 @@ def print_s(force=True): rep_nk = len(new) - (len(bz) - bz_nk) print("replaced k-points ", rep_nk, file=s) print_s() - #print(len(bz)*4 * 8 / 1024**3) + # print(len(bz)*4 * 8 / 1024**3) del new @@ -200,6 +213,6 @@ def print_s(force=True): stat = pstats.Stats(pr) # We sort against total-time - stat.sort_stats('tottime') + stat.sort_stats("tottime") # Only print the first 20% of the routines. - stat.print_stats('sisl', 0.2) + stat.print_stats("sisl", 0.2) diff --git a/benchmarks/graphene.py b/benchmarks/graphene.py index 1db3a17c24..8e2600a5c6 100644 --- a/benchmarks/graphene.py +++ b/benchmarks/graphene.py @@ -37,7 +37,7 @@ gr = sisl.geom.graphene(orthogonal=True).tile(N, 0).tile(N, 1) H = sisl.Hamiltonian(gr) pr.enable() -H.construct([(0.1, 1.44), (0., -2.7)], eta=True) +H.construct([(0.1, 1.44), (0.0, -2.7)], eta=True) H.finalize() pr.disable() pr.dump_stats(f"{sys.argv[0]}.profile") @@ -45,6 +45,6 @@ stat = pstats.Stats(pr) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/benchmarks/graphene_io.py b/benchmarks/graphene_io.py index 416487a84e..0277cf50df 100644 --- a/benchmarks/graphene_io.py +++ b/benchmarks/graphene_io.py @@ -47,6 +47,6 @@ stat = pstats.Stats(pr) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/benchmarks/graphene_repeat.py b/benchmarks/graphene_repeat.py index 6ab18c760b..c38dcc4afa 100644 --- a/benchmarks/graphene_repeat.py +++ b/benchmarks/graphene_repeat.py @@ -36,7 +36,7 @@ gr = sisl.geom.graphene(orthogonal=True) H = sisl.Hamiltonian(gr) -H.construct([(0.1, 1.44), (0., -2.7)]) +H.construct([(0.1, 1.44), (0.0, -2.7)]) pr.enable() H.repeat(N, 0).repeat(N, 1) H.finalize() @@ -46,6 +46,6 @@ stat = pstats.Stats(pr) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/benchmarks/graphene_tile.py b/benchmarks/graphene_tile.py index c4f28eae16..8e816f3e80 100644 --- a/benchmarks/graphene_tile.py +++ b/benchmarks/graphene_tile.py @@ -36,7 +36,7 @@ gr = sisl.geom.graphene(orthogonal=True) H = sisl.Hamiltonian(gr) -H.construct([(0.1, 1.44), (0., -2.7)]) +H.construct([(0.1, 1.44), (0.0, -2.7)]) pr.enable() H = H.tile(N, 0).tile(N, 1) H.finalize() @@ -46,6 +46,6 @@ stat = pstats.Stats(pr) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/benchmarks/graphene_within.py b/benchmarks/graphene_within.py index b27362519f..f9f4f5243c 100644 --- a/benchmarks/graphene_within.py +++ b/benchmarks/graphene_within.py @@ -21,13 +21,13 @@ import sisl -method = 'cube' -if 'cube' in sys.argv: - method = 'cube' - sys.argv.remove('cube') -elif 'sphere' in sys.argv: - method = 'sphere' - sys.argv.remove('sphere') +method = "cube" +if "cube" in sys.argv: + method = "cube" + sys.argv.remove("cube") +elif "sphere" in sys.argv: + method = "sphere" + sys.argv.remove("sphere") if len(sys.argv) > 1: N = int(sys.argv[1]) @@ -37,5 +37,5 @@ gr = sisl.geom.graphene(orthogonal=True).tile(N, 0).tile(N, 1) H = sisl.Hamiltonian(gr) -H.construct([(0.1, 1.44), (0., -2.7)], method=method, eta=True) +H.construct([(0.1, 1.44), (0.0, -2.7)], method=method, eta=True) H.finalize() diff --git a/benchmarks/optimizations/array_arange.ipynb b/benchmarks/optimizations/array_arange.ipynb index 6ad9f1bbfa..aa43076f76 100644 --- a/benchmarks/optimizations/array_arange.ipynb +++ b/benchmarks/optimizations/array_arange.ipynb @@ -9,10 +9,8 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", @@ -20,6 +18,7 @@ "from itertools import izip as zip, imap as map\n", "from numpy import arange\n", "import numpy as np\n", + "\n", "onesi = partial(np.ones, dtype=np.int32)\n", "cumsumi = partial(np.cumsum, dtype=np.int32)\n", "%matplotlib inline" @@ -27,10 +26,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def aarng(start, end):\n", @@ -39,10 +36,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def aarng_for_list(start, end):\n", @@ -51,10 +46,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def aarng_for_gen(start, end):\n", @@ -63,10 +56,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def aarng_concat(start, end):\n", @@ -75,10 +66,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def aarng_concat_list(start, end):\n", @@ -87,10 +76,8 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def aarng_concat_tuple(start, end):\n", @@ -99,18 +86,9 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 3 4 5 8 50 51 52 53 80 81]\n", - "[ 3 4 5 8 50 51 52 53 80 81]\n" - ] - } - ], + "outputs": [], "source": [ "def aarng_one(start, end):\n", " n = end - start\n", @@ -126,10 +104,8 @@ }, { "cell_type": "code", - "execution_count": 21, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "# Samples\n", @@ -140,119 +116,63 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 loop, best of 3: 1.17 s per loop\n" - ] - } - ], + "outputs": [], "source": [ "%timeit np.array((aarng(start, end)))" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 loop, best of 3: 1.15 s per loop\n" - ] - } - ], + "outputs": [], "source": [ "%timeit np.array(aarng_for_list(start, end))" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 loop, best of 3: 1.21 s per loop\n" - ] - } - ], + "outputs": [], "source": [ "%timeit np.array(aarng_for_gen(start, end))" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 loop, best of 3: 992 ms per loop\n" - ] - } - ], + "outputs": [], "source": [ "%timeit np.array(aarng_concat(start, end))" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 loop, best of 3: 985 ms per loop\n" - ] - } - ], + "outputs": [], "source": [ "%timeit np.array(aarng_concat_list(start, end))" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 loop, best of 3: 989 ms per loop\n" - ] - } - ], + "outputs": [], "source": [ "%timeit np.array(aarng_concat_tuple(start, end))" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 loop, best of 3: 550 ms per loop\n" - ] - } - ], + "outputs": [], "source": [ "%timeit np.array(aarng_one(start, end))" ] @@ -260,9 +180,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [] } diff --git a/benchmarks/sparse_matrices.py b/benchmarks/sparse_matrices.py index 27da237d2a..7b9e591d16 100644 --- a/benchmarks/sparse_matrices.py +++ b/benchmarks/sparse_matrices.py @@ -54,6 +54,6 @@ stat = pstats.Stats(pr) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/benchmarks/sparse_matrices_ufunc.py b/benchmarks/sparse_matrices_ufunc.py index 09e09536c6..e5a5598f27 100644 --- a/benchmarks/sparse_matrices_ufunc.py +++ b/benchmarks/sparse_matrices_ufunc.py @@ -64,6 +64,6 @@ stat = pstats.Stats(pr) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/benchmarks/stats.py b/benchmarks/stats.py index 87a917105f..5c637215c0 100644 --- a/benchmarks/stats.py +++ b/benchmarks/stats.py @@ -4,6 +4,7 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. import pstats + # Script for analysing profile scripts created by the # cProfile module. import sys @@ -11,11 +12,11 @@ if len(sys.argv) > 1: fname = sys.argv[1] else: - raise ValueError('Must supply a profile file-name') + raise ValueError("Must supply a profile file-name") stat = pstats.Stats(fname) # We sort against total-time -stat.sort_stats('tottime') +stat.sort_stats("tottime") # Only print the first 20% of the routines. -stat.print_stats('sisl', 0.2) +stat.print_stats("sisl", 0.2) diff --git a/developments/distributions.py b/developments/distributions.py index ef62d1836b..ea125d394e 100644 --- a/developments/distributions.py +++ b/developments/distributions.py @@ -42,29 +42,41 @@ def __eq__(self, name): distributions = [ - Distribution("fd", - pdf=(1-1/(sy.exp(x)+1)).diff(x), - sf=1/(sy.exp(x)+1)), - Distribution("mp", - pdf=sy.exp(-(x)**2)/sy.sqrt(sy.pi)*sy.Sum(hermite(2*n, x), (n, 0, N)), - cdf=1/sy.sqrt(sy.pi)*sy.Sum( - (sy.exp(-(x)**2)*hermite(2*n, x)) - .integrate(x) - .doit(simplify=True) - .expand() - .simplify(), (n, 0, N)), - entropy=-1/sy.sqrt(sy.pi)*sy.Sum( - (sy.exp(-(x)**2)*hermite(2*n, x) * x) - .integrate((x, -sy.oo, y)).subs(y, x) - .doit(simplify=True) - .expand() - .simplify(), (n, 0, N))), - Distribution("gaussian", - pdf=sy.exp(-(x)**2/2)/sy.sqrt(2*sy.pi)), - Distribution("cauchy", - pdf=1/(sy.pi*(1+(x)**2))), - Distribution("cold", - pdf=1/sy.sqrt(sy.pi)*sy.exp(-(-x-1/sy.sqrt(2))**2)*(2+sy.sqrt(2)*x)) + Distribution("fd", pdf=(1 - 1 / (sy.exp(x) + 1)).diff(x), sf=1 / (sy.exp(x) + 1)), + Distribution( + "mp", + pdf=sy.exp(-((x) ** 2)) / sy.sqrt(sy.pi) * sy.Sum(hermite(2 * n, x), (n, 0, N)), + cdf=1 + / sy.sqrt(sy.pi) + * sy.Sum( + (sy.exp(-((x) ** 2)) * hermite(2 * n, x)) + .integrate(x) + .doit(simplify=True) + .expand() + .simplify(), + (n, 0, N), + ), + entropy=-1 + / sy.sqrt(sy.pi) + * sy.Sum( + (sy.exp(-((x) ** 2)) * hermite(2 * n, x) * x) + .integrate((x, -sy.oo, y)) + .subs(y, x) + .doit(simplify=True) + .expand() + .simplify(), + (n, 0, N), + ), + ), + Distribution("gaussian", pdf=sy.exp(-((x) ** 2) / 2) / sy.sqrt(2 * sy.pi)), + Distribution("cauchy", pdf=1 / (sy.pi * (1 + (x) ** 2))), + Distribution( + "cold", + pdf=1 + / sy.sqrt(sy.pi) + * sy.exp(-((-x - 1 / sy.sqrt(2)) ** 2)) + * (2 + sy.sqrt(2) * x), + ), ] @@ -99,38 +111,44 @@ def __eq__(self, name): dist.cdf = (dist.cdf - cneg).expand().simplify() print(f" cdf = {dist.cdf}") - #print(f" cdf*= {dist.cdf.subs(dist.pdf, pdf)}") + # print(f" cdf*= {dist.cdf.subs(dist.pdf, pdf)}") # plot it... func = dist.pdf.subs(N, 0).subs(c, 1).expand().simplify() - func = sy.lambdify(x, func, 'numpy') + func = sy.lambdify(x, func, "numpy") axs[0].plot(E, func(E), label=dist.name) if dist.sf is None: dist.sf = (1 - dist.cdf).expand().simplify() print(f" sf|theta = {dist.sf}") - #print(f" d sf|theta = {-dist.pdf.expand().doit(simplify=True).simplify()}") - #print(f" sf|theta*= {dist.sf.subs(dist.pdf, pdf).subs(dist.cdf, cdf)}") + # print(f" d sf|theta = {-dist.pdf.expand().doit(simplify=True).simplify()}") + # print(f" sf|theta*= {dist.sf.subs(dist.pdf, pdf).subs(dist.cdf, cdf)}") func = dist.sf.subs(N, 0).subs(c, 1).expand().simplify() - func = sy.lambdify(x, func, 'numpy') + func = sy.lambdify(x, func, "numpy") try: # cold function may fail axs[1].plot(E, [func(e) for e in E], label=dist.name) - except Exception: pass + except Exception: + pass if dist.entropy is None: - dist.entropy = -(dist.pdf*x).integrate((x, -sy.oo, x)).doit(simplify=True).simplify() + dist.entropy = ( + -(dist.pdf * x).integrate((x, -sy.oo, x)).doit(simplify=True).simplify() + ) print(f" entropy = {dist.entropy}") func = dist.entropy.subs(N, 0).subs(c, 1).expand().simplify() - func = sy.lambdify(x, func, 'numpy') + func = sy.lambdify(x, func, "numpy") try: # cold function may fail axs[2].plot(E, [func(e) for e in E], label=dist.name) - except Exception: pass + except Exception: + pass - var = (dist.pdf*x*x).integrate((x, -sy.oo, sy.oo)).doit(simplify=True).simplify() + var = ( + (dist.pdf * x * x).integrate((x, -sy.oo, sy.oo)).doit(simplify=True).simplify() + ) print(f" variance = {var}") @@ -138,8 +156,8 @@ def __eq__(self, name): ifd = distributions.index("fd") fd = distributions[ifd] # Check that it finds the same entropy -fd_enpy = -(fd.sf * sy.log(fd.sf) + (1-fd.sf)*sy.log(1-fd.sf)) -assert (fd_enpy - fd.entropy).simplify() == 0. +fd_enpy = -(fd.sf * sy.log(fd.sf) + (1 - fd.sf) * sy.log(1 - fd.sf)) +assert (fd_enpy - fd.entropy).simplify() == 0.0 axs[0].legend() diff --git a/developments/miller.py b/developments/miller.py index ffd0a77322..32bbf0d680 100644 --- a/developments/miller.py +++ b/developments/miller.py @@ -60,13 +60,13 @@ def get_miller(rcell, hkl): r0, t0, p0 = si.utils.math.cart2spher(v0) # Create a rotation matrix that rotates the first vector to be along the # first lattice. - q0 = si.Quaternion(-p0, [0, 0, 1.], True) - q1 = si.Quaternion(-t0, [1., 0, 0], True) + q0 = si.Quaternion(-p0, [0, 0, 1.0], True) + q1 = si.Quaternion(-t0, [1.0, 0, 0], True) q = q0 * q1 rv = q.rotate(rv) # Remove too small numbers rv = np.where(np.abs(rv) < 1e-10, 0, rv) - v = np.linalg.inv(rv) * 2. * np.pi + v = np.linalg.inv(rv) * 2.0 * np.pi print(v) return q.rotate(v) diff --git a/docs/conf.py b/docs/conf.py index 08b846be39..6f6b7351f1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,12 +26,12 @@ _root = pathlib.Path(__file__).absolute().parent.parent / "src" # If building this on RTD, mock out fortran sources -on_rtd = os.environ.get('READTHEDOCS', 'false').lower() == 'true' +on_rtd = os.environ.get("READTHEDOCS", "false").lower() == "true" if on_rtd: os.environ["SISL_NUM_PROCS"] = "1" os.environ["SISL_VIZ_NUM_PROCS"] = "1" -#sys.path.insert(0, str(_root)) +# sys.path.insert(0, str(_root)) # Print standard information about executable and path... print("python exec:", sys.executable) @@ -62,7 +62,7 @@ "sphinx.ext.todo", "sphinx.ext.viewcode", "sphinx_autodoc_typehints", - "sphinxcontrib.jquery", # a bug in 4.1.0 means search didn't work without explicit extension + "sphinxcontrib.jquery", # a bug in 4.1.0 means search didn't work without explicit extension "sphinx_inline_tabs", # plotting and advanced usage "matplotlib.sphinxext.plot_directive", @@ -80,22 +80,22 @@ # The default is MathJax 3. # In case we want to revert to 2.7.7, then use the below link: -#mathjax_path = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML" +# mathjax_path = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML" # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # Short-hand for :doi: extlinks = { - 'issue': ('https://github.com/zerothi/sisl/issues/%s', 'issue %s'), - 'pull': ('https://github.com/zerothi/sisl/pull/%s', 'pull request %s'), - 'doi': ('https://doi.org/%s', '%s'), + "issue": ("https://github.com/zerothi/sisl/issues/%s", "issue %s"), + "pull": ("https://github.com/zerothi/sisl/pull/%s", "pull request %s"), + "doi": ("https://doi.org/%s", "%s"), } # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # prepend/append this snippet in _all_ sources rst_prolog = """ @@ -103,7 +103,7 @@ """ # Insert the links into the epilog (globally) # This means that every document has access to the links -rst_epilog = ''.join(open('epilog.dummy').readlines()) +rst_epilog = "".join(open("epilog.dummy").readlines()) autosummary_generate = True @@ -131,19 +131,25 @@ # Add __init__ classes to the documentation -autoclass_content = 'class' +autoclass_content = "class" autodoc_default_options = { - 'members': True, - 'undoc-members': True, - 'special-members': '__init__,__call__', - 'inherited-members': True, - 'show-inheritance': True, + "members": True, + "undoc-members": True, + "special-members": "__init__,__call__", + "inherited-members": True, + "show-inheritance": True, } # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['build', '**/setupegg.py', '**/setup.rst', '**/tests', '**.ipynb_checkpoints'] +exclude_patterns = [ + "build", + "**/setupegg.py", + "**/setup.rst", + "**/tests", + "**.ipynb_checkpoints", +] exclude_patterns.append("**/GUI with Python Demo.ipynb") exclude_patterns.append("**/Building a plot class.ipynb") for _venv in pathlib.Path(".").glob("*venv*"): @@ -151,7 +157,7 @@ # The reST default role (used for this markup: `text`) to use for all # documents. -default_role = 'autolink' +default_role = "autolink" # If true, '()' will be appended to :func: etc. cross-reference text. add_function_parentheses = False @@ -161,7 +167,7 @@ show_authors = False # A list of ignored prefixes for module index sorting. -modindex_common_prefix = ['sisl.'] +modindex_common_prefix = ["sisl."] # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -170,7 +176,7 @@ # -- Options for HTML output ---------------------------------------------- html_theme = "sphinx_rtd_theme" -#html_theme = "furo" +# html_theme = "furo" if html_theme == "furo": html_theme_options = { @@ -189,14 +195,14 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -if os.path.exists('_static'): - html_static_path = ['_static'] +if os.path.exists("_static"): + html_static_path = ["_static"] else: html_static_path = [] # Add any extra style files that we need html_css_files = [ - 'css/custom_styles.css', + "css/custom_styles.css", ] # If false, no index is generated. @@ -206,25 +212,21 @@ # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -'papersize': 'a4paper', - -# The font size ('10pt', '11pt' or '12pt'). -'pointsize': '11pt', - -# Additional stuff for the LaTeX preamble. -'preamble': r"", - -# Latex figure (float) alignment -'figure_align': '!htbp', + # The paper size ('letterpaper' or 'a4paper'). + "papersize": "a4paper", + # The font size ('10pt', '11pt' or '12pt'). + "pointsize": "11pt", + # Additional stuff for the LaTeX preamble. + "preamble": r"", + # Latex figure (float) alignment + "figure_align": "!htbp", } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ("index", 'sisl.tex', 'sisl Documentation', - 'Nick Papior', 'manual'), + ("index", "sisl.tex", "sisl Documentation", "Nick Papior", "manual"), ] ##### @@ -234,9 +236,9 @@ # These two options should solve the "toctree contains reference to nonexisting document" # problem. # See here: numpydoc #69 -#class_members_toctree = False +# class_members_toctree = False # If this is false we do not have double method sections -#numpydoc_show_class_members = False +# numpydoc_show_class_members = False # ----------------------------------------------------------------------------- # Intersphinx configuration @@ -244,14 +246,14 @@ # Python, numpy, scipy and matplotlib specify https as the default objects.inv # directory. So please retain these links. intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'scipy': ('https://docs.scipy.org/doc/scipy/', None), - 'matplotlib': ('https://matplotlib.org/stable/', None), - 'xarray': ('https://docs.xarray.dev/en/stable/', None), - 'plotly': ('https://plotly.com/python-api-reference/', None), - 'skimage': ('https://scikit-image.org/docs/stable', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable', None), + "python": ("https://docs.python.org/3/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), + "xarray": ("https://docs.xarray.dev/en/stable/", None), + "plotly": ("https://plotly.com/python-api-reference/", None), + "skimage": ("https://scikit-image.org/docs/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), } @@ -278,14 +280,13 @@ class YearMonthAuthorSortStyle(AYTSortingStyle): - def sorting_key(self, entry): ayt = super().sorting_key(entry) year = self._year_number(entry) month = self._month_number(entry) - return (-year, -month , ayt[0], ayt[2]) + return (-year, -month, ayt[0], ayt[2]) def _year_number(self, entry): year = entry.fields.get("year", 0) @@ -307,10 +308,13 @@ def _month_number(self, entry): class RevYearPlain(PlainStyle): default_sorting_style = "sort_rev_year" + import pybtex -pybtex.plugin.register_plugin('pybtex.style.sorting', 'sort_rev_year', YearMonthAuthorSortStyle) -pybtex.plugin.register_plugin('pybtex.style.formatting', 'rev_year', RevYearPlain) +pybtex.plugin.register_plugin( + "pybtex.style.sorting", "sort_rev_year", YearMonthAuthorSortStyle +) +pybtex.plugin.register_plugin("pybtex.style.formatting", "rev_year", RevYearPlain) # Tell nbsphinx to wait, at least X seconds for each cell nbsphinx_timeout = 600 @@ -341,12 +345,15 @@ def sisl_method2class(meth): if cls.__dict__.get(meth.__name__) is meth: return cls if inspect.isfunction(meth): - cls = getattr(inspect.getmodule(meth), - meth.__qualname__.split('.', 1)[0].rsplit('.', 1)[0]) + cls = getattr( + inspect.getmodule(meth), + meth.__qualname__.split(".", 1)[0].rsplit(".", 1)[0], + ) if isinstance(cls, type): return cls return None # not required since None would have been implicitly returned anyway + # My custom detailed instructions for not documenting stuff @@ -355,20 +362,31 @@ def sisl_skip(app, what, name, obj, skip, options): # When adding routines here, please also add them # to the _templates/autosummary/class.rst file to limit # the documentation. - if what == 'class': - if name in ['ArgumentParser', 'ArgumentParser_out', - 'is_keys', 'key2case', 'keys2case', - 'line_has_key', 'line_has_keys', 'readline', - 'step_to', - 'isDataset', 'isDimension', 'isGroup', - 'isRoot', 'isVariable']: + if what == "class": + if name in [ + "ArgumentParser", + "ArgumentParser_out", + "is_keys", + "key2case", + "keys2case", + "line_has_key", + "line_has_keys", + "readline", + "step_to", + "isDataset", + "isDimension", + "isGroup", + "isRoot", + "isVariable", + ]: return True - #elif what == "attribute": + # elif what == "attribute": # return True # check for special methods (we don't want all) - if (name.startswith("_") and - name not in autodoc_default_options.get("special-members", '').split(',')): + if name.startswith("_") and name not in autodoc_default_options.get( + "special-members", "" + ).split(","): return True try: @@ -384,20 +402,34 @@ def sisl_skip(app, what, name, obj, skip, options): # Apparently they will be linked directly. # Now we have some things to disable the output of if "projncSile" in cls.__name__: - if name in ["current", "current_parameter", "shot_noise", - "noise_power", "fano", "density_matrix", - "write_tbtav", - "orbital_COOP", "atom_COOP", - "orbital_COHP", "atom_COHP"]: + if name in [ + "current", + "current_parameter", + "shot_noise", + "noise_power", + "fano", + "density_matrix", + "write_tbtav", + "orbital_COOP", + "atom_COOP", + "orbital_COHP", + "atom_COHP", + ]: return True if "SilePHtrans" in cls.__name__: - if name in ["chemical_potential", "electron_temperature", - "kT", "current", "current_parameter", "shot_noise", - "noise_power"]: + if name in [ + "chemical_potential", + "electron_temperature", + "kT", + "current", + "current_parameter", + "shot_noise", + "noise_power", + ]: return True return None def setup(app): # Setup autodoc skipping - app.connect('autodoc-skip-member', sisl_skip) + app.connect("autodoc-skip-member", sisl_skip) diff --git a/docs/nodes/nodes_intro.ipynb b/docs/nodes/nodes_intro.ipynb index 448343dc9a..2ad9dc3d63 100644 --- a/docs/nodes/nodes_intro.ipynb +++ b/docs/nodes/nodes_intro.ipynb @@ -51,6 +51,7 @@ " print(f\"SUMMING {a} + {b}\")\n", " return a + b\n", "\n", + "\n", "# Instead of using it as a decorator, if you want to keep the pristine function,\n", "# you can always create the node later:\n", "#\n", @@ -172,7 +173,7 @@ "auto_result = my_sum(2, 5)\n", "\n", "auto_result.context.update(lazy=False)\n", - " \n", + "\n", "auto_result.get()\n", "auto_result.update_inputs(a=8)" ] @@ -198,7 +199,7 @@ "first_val = my_sum(2, 5)\n", "# Use the first value to compute our final value\n", "final_val = my_sum(first_val, 5)\n", - " \n", + "\n", "final_val.get()" ] }, @@ -310,16 +311,19 @@ "source": [ "from sisl.nodes import Workflow\n", "\n", + "\n", "def my_sum(a, b):\n", " print(f\"SUMMING {a} + {b}\")\n", " return a + b\n", "\n", + "\n", "# Define our workchain as a workflow.\n", "@Workflow.from_func\n", "def triple_sum(a: int, b: int, c: int):\n", " first_val = my_sum(a, b)\n", " return my_sum(first_val, c)\n", "\n", + "\n", "# Again, if you want to keep the pristine function,\n", "# don't use the decorator\n", "#\n", @@ -342,9 +346,7 @@ "cell_type": "code", "execution_count": null, "id": "1da1d4ec", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "wf_nodes = triple_sum.dryrun_nodes\n", @@ -522,10 +524,11 @@ "@Node.from_func(context={\"lazy\": False})\n", "def alert_change(val: int):\n", " print(f\"VALUE CHANGED, it now is {val}\")\n", - " \n", - "# We feed the node that produces the intermediate value into our alert node \n", + "\n", + "\n", + "# We feed the node that produces the intermediate value into our alert node\n", "my_alert = alert_change(result.nodes.first_val)\n", - " \n", + "\n", "# Now when we update the inputs of the workflow, the node will propagate the information through\n", "# our new node.\n", "result.update_inputs(a=10)" @@ -547,19 +550,18 @@ "outputs": [], "source": [ "class TripleSum(Workflow):\n", - " \n", " # Define the function that runs the workflow, exactly as we did before.\n", " @staticmethod\n", " def function(a: int, b: int, c: int):\n", " first_val = my_sum(a, b)\n", " return my_sum(first_val, c)\n", - " \n", + "\n", " # Now, we have the possibility of adding new methods to it.\n", " def scale(self, factor: int):\n", " self.update_inputs(\n", - " a=self.get_input('a')*factor,\n", - " b=self.get_input('b')*factor,\n", - " c=self.get_input('c')*factor\n", + " a=self.get_input(\"a\") * factor,\n", + " b=self.get_input(\"b\") * factor,\n", + " c=self.get_input(\"c\") * factor,\n", " )" ] }, @@ -613,9 +615,7 @@ "cell_type": "code", "execution_count": null, "id": "95c5e1d2", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "@Workflow.from_func\n", @@ -623,7 +623,10 @@ " val = a + b + c\n", " return val\n", "\n", - "sum_triple.network.visualize(notebook=True, )" + "\n", + "sum_triple.network.visualize(\n", + " notebook=True,\n", + ")" ] }, { @@ -644,11 +647,13 @@ "def operation(a, b, c):\n", " return a + b + c\n", "\n", + "\n", "@Workflow.from_func\n", "def sum_triple(a, b, c):\n", " val = operation(a, b, c)\n", " return val\n", "\n", + "\n", "sum_triple.network.visualize(notebook=True)" ] }, diff --git a/docs/tutorials/tutorial_05_graphene.py b/docs/tutorials/tutorial_05_graphene.py index fceaaccd88..4196b56fb7 100644 --- a/docs/tutorials/tutorial_05_graphene.py +++ b/docs/tutorials/tutorial_05_graphene.py @@ -22,7 +22,7 @@ print(H) # Create band-structure for the supercell. -band = BandStructure(H, [[0., 0.], [2./3, 1./3], [0.5, 0.5], [0., 0.]], 300) +band = BandStructure(H, [[0.0, 0.0], [2.0 / 3, 1.0 / 3], [0.5, 0.5], [0.0, 0.0]], 300) # Calculate eigenvalues of the band-structure eigs = band.eigh() @@ -31,12 +31,12 @@ import matplotlib.pyplot as plt plt.figure() -plt.title('Bandstructure of graphene, nearest neighbour') -plt.xlabel('k') -plt.ylabel('Eigenvalue') +plt.title("Bandstructure of graphene, nearest neighbour") +plt.xlabel("k") +plt.ylabel("Eigenvalue") # Generate linear-k for plotting (ensures correct spacing) lband = band.lineark() for i in range(eigs.shape[1]): plt.plot(lband, eigs[:, i]) -plt.savefig('05_graphene_bs.png') +plt.savefig("05_graphene_bs.png") diff --git a/docs/tutorials/tutorial_05_square.py b/docs/tutorials/tutorial_05_square.py index d191a19d43..fbfa3be21a 100644 --- a/docs/tutorials/tutorial_05_square.py +++ b/docs/tutorials/tutorial_05_square.py @@ -13,17 +13,17 @@ print(H) # Specify matrix elements (on-site and coupling elements) -H[0, 0] = -4. -H[0, 0, (1, 0)] = 1. -H[0, 0, (-1, 0)] = 1. -H[0, 0, (0, 1)] = 1. -H[0, 0, (0, -1)] = 1. +H[0, 0] = -4.0 +H[0, 0, (1, 0)] = 1.0 +H[0, 0, (-1, 0)] = 1.0 +H[0, 0, (0, 1)] = 1.0 +H[0, 0, (0, -1)] = 1.0 # Show that we indeed have added some code print(H) # Create band-structure for the supercell. -band = BandStructure(H, [[0., 0.], [0.5, 0.], [0.5, 0.5], [0., 0.]], 300) +band = BandStructure(H, [[0.0, 0.0], [0.5, 0.0], [0.5, 0.5], [0.0, 0.0]], 300) # Calculate eigenvalues of the band-structure eigs = band.eigh() @@ -32,12 +32,12 @@ import matplotlib.pyplot as plt plt.figure() -plt.title('Bandstructure of square, nearest neighbour') -plt.xlabel('k') -plt.ylabel('Eigenvalue') +plt.title("Bandstructure of square, nearest neighbour") +plt.xlabel("k") +plt.ylabel("Eigenvalue") # Generate linear-k for plotting (ensures correct spacing) lband = band.lineark() for i in range(eigs.shape[1]): plt.plot(lband, eigs[:, i]) -plt.savefig('05_square_bs.png') +plt.savefig("05_square_bs.png") diff --git a/docs/tutorials/tutorial_06_square.py b/docs/tutorials/tutorial_06_square.py index 8be7a4e172..209eafe15a 100644 --- a/docs/tutorials/tutorial_06_square.py +++ b/docs/tutorials/tutorial_06_square.py @@ -4,9 +4,8 @@ from sisl import * # Generate square lattice with nearest neighbour couplings -Hydrogen = Atom(1, R=1.) -square = Geometry([[0.5, 0.5, 0]], Hydrogen, - lattice=Lattice([1, 1, 10], [3, 3, 1])) +Hydrogen = Atom(1, R=1.0) +square = Geometry([[0.5, 0.5, 0]], Hydrogen, lattice=Lattice([1, 1, 10], [3, 3, 1])) # Generate Hamiltonian H = Hamiltonian(square) @@ -18,14 +17,14 @@ for ias, idxs in square.iter_block(): for ia in ias: idx_a = square.close(ia, R=[0.1, 1.1], atoms=idxs) - H[ia, idx_a[0]] = -4. - H[ia, idx_a[1]] = 1. + H[ia, idx_a[0]] = -4.0 + H[ia, idx_a[1]] = 1.0 # Show that we indeed have added some code print(H) # Create band-structure for the supercell. -band = BandStructure(H, [[0., 0.], [0.5, 0.], [0.5, 0.5], [0., 0.]], 300) +band = BandStructure(H, [[0.0, 0.0], [0.5, 0.0], [0.5, 0.5], [0.0, 0.0]], 300) # Calculate eigenvalues of the band-structure eigs = band.eigh() @@ -34,12 +33,12 @@ import matplotlib.pyplot as plt plt.figure() -plt.title('Bandstructure of square, nearest neighbour') -plt.xlabel('k') -plt.ylabel('Eigenvalue') +plt.title("Bandstructure of square, nearest neighbour") +plt.xlabel("k") +plt.ylabel("Eigenvalue") # Generate linear-k for plotting (ensures correct spacing) lband = band.lineark() for i in range(eigs.shape[1]): plt.plot(lband, eigs[:, i]) -plt.savefig('06_square_bs.png') +plt.savefig("06_square_bs.png") diff --git a/docs/tutorials/tutorial_es_1.ipynb b/docs/tutorials/tutorial_es_1.ipynb index cc2b1b4a21..eb011fa1ad 100644 --- a/docs/tutorials/tutorial_es_1.ipynb +++ b/docs/tutorials/tutorial_es_1.ipynb @@ -12,6 +12,7 @@ "from sisl.viz import merge_plots\n", "from sisl.viz.processors.math import normalize\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline" ] }, @@ -97,7 +98,9 @@ "outputs": [], "source": [ "coord = graphene.sub(0)\n", - "coord.plot(axes=\"xy\", atoms_style={\"color\": \"red\"}).merge(graphene.remove(0).plot(axes=\"xy\"))" + "coord.plot(axes=\"xy\", atoms_style={\"color\": \"red\"}).merge(\n", + " graphene.remove(0).plot(axes=\"xy\")\n", + ")" ] }, { @@ -115,15 +118,28 @@ "metadata": {}, "outputs": [], "source": [ - "xyz_center = graphene.center(what='xyz')\n", + "xyz_center = graphene.center(what=\"xyz\")\n", "indices = graphene.close(xyz_center, 1.5)\n", "index = indices[0]\n", "system = graphene.remove(index)\n", - "graphene.plot(axes=\"xy\", atoms_style=[\n", - " {\"opacity\": 0.5}, # Default style for all atoms\n", - " {\"atoms\": indices, \"color\": \"black\", \"size\": 1.2, \"opacity\": 1}, # Styling for indices_close_to_center on top of defaults.\n", - " {\"atoms\": index, \"color\": \"red\", \"size\": 1, \"opacity\": 1} # Styling for center_atom_index on top of defaults.\n", - "])" + "graphene.plot(\n", + " axes=\"xy\",\n", + " atoms_style=[\n", + " {\"opacity\": 0.5}, # Default style for all atoms\n", + " {\n", + " \"atoms\": indices,\n", + " \"color\": \"black\",\n", + " \"size\": 1.2,\n", + " \"opacity\": 1,\n", + " }, # Styling for indices_close_to_center on top of defaults.\n", + " {\n", + " \"atoms\": index,\n", + " \"color\": \"red\",\n", + " \"size\": 1,\n", + " \"opacity\": 1,\n", + " }, # Styling for center_atom_index on top of defaults.\n", + " ],\n", + ")" ] }, { @@ -163,8 +179,8 @@ "metadata": {}, "outputs": [], "source": [ - "r = (0.1, 1.44)\n", - "t = (0. , -2.7 )\n", + "r = (0.1, 1.44)\n", + "t = (0.0, -2.7)\n", "H.construct([r, t])\n", "print(H)" ] @@ -199,7 +215,7 @@ "es_fermi = es.sub(range(len(H) // 2 - 1, len(H) // 2 + 2))\n", "\n", "plots = [\n", - " system.plot(axes=\"xy\", atoms_style=[{\"size\": n * 20, \"color\": c}]) \n", + " system.plot(axes=\"xy\", atoms_style=[{\"size\": n * 20, \"color\": c}])\n", " for n, c in zip(es_fermi.norm2(sum=False), (\"red\", \"blue\", \"green\"))\n", "]\n", "\n", @@ -226,8 +242,9 @@ "outputs": [], "source": [ "E = np.linspace(-4, 4, 400)\n", - "plt.plot(E, es.DOS(E)); \n", - "plt.xlabel(r'$E - E_F$ [eV]'); plt.ylabel(r'DOS at $\\Gamma$ [1/eV]');" + "plt.plot(E, es.DOS(E))\n", + "plt.xlabel(r\"$E - E_F$ [eV]\")\n", + "plt.ylabel(r\"DOS at $\\Gamma$ [1/eV]\");" ] }, { @@ -243,12 +260,12 @@ "metadata": {}, "outputs": [], "source": [ - "E = np.linspace(-1, -.5, 100)\n", + "E = np.linspace(-1, -0.5, 100)\n", "dE = E[1] - E[0]\n", - "PDOS = es.PDOS(E).sum((0, 2)) * dE # perform integration\n", + "PDOS = es.PDOS(E).sum((0, 2)) * dE # perform integration\n", "system.plot(axes=\"xy\", atoms_style={\"size\": normalize(PDOS, 0, 1)})\n", - "#plt.scatter(system.xyz[:, 0], system.xyz[:, 1], 500 * PDOS);\n", - "#plt.scatter(xyz_remove[0], xyz_remove[1], c='k', marker='*'); # mark the removed atom" + "# plt.scatter(system.xyz[:, 0], system.xyz[:, 1], 500 * PDOS);\n", + "# plt.scatter(xyz_remove[0], xyz_remove[1], c='k', marker='*'); # mark the removed atom" ] }, { @@ -275,10 +292,12 @@ "metadata": {}, "outputs": [], "source": [ - "band = BandStructure(H, [[0, 0, 0], [0, 0.5, 0], \n", - " [1/3, 2/3, 0], [0, 0, 0]], 400, \n", - " [r'Gamma', r'M', \n", - " r'K', r'Gamma'])" + "band = BandStructure(\n", + " H,\n", + " [[0, 0, 0], [0, 0.5, 0], [1 / 3, 2 / 3, 0], [0, 0, 0]],\n", + " 400,\n", + " [r\"Gamma\", r\"M\", r\"K\", r\"Gamma\"],\n", + ")" ] }, { @@ -332,7 +351,9 @@ "outputs": [], "source": [ "bz = MonkhorstPack(H, [35, 35, 1])\n", - "bz_average = bz.apply.average; # specify the Brillouin zone to perform an average of subsequent method calls" + "bz_average = (\n", + " bz.apply.average\n", + "); # specify the Brillouin zone to perform an average of subsequent method calls" ] }, { @@ -342,9 +363,9 @@ "outputs": [], "source": [ "E = np.linspace(-4, 4, 1000)\n", - "plt.plot(E, bz_average.eigenstate(wrap=lambda es: es.DOS(E)));\n", - "plt.xlabel('$E - E_F$ [eV]');\n", - "plt.ylabel('DOS [1/eV]');" + "plt.plot(E, bz_average.eigenstate(wrap=lambda es: es.DOS(E)))\n", + "plt.xlabel(\"$E - E_F$ [eV]\")\n", + "plt.ylabel(\"DOS [1/eV]\");" ] }, { @@ -393,8 +414,8 @@ "source": [ "r = np.linspace(0, 1.6, 700)\n", "f = 5 * np.exp(-r * 5)\n", - "print('Normalization: {}'.format(f.sum() * (r[1] - r[0])))\n", - "plt.plot(r, f);\n", + "print(\"Normalization: {}\".format(f.sum() * (r[1] - r[0])))\n", + "plt.plot(r, f)\n", "plt.ylim([0, None])\n", "orb = SphericalOrbital(1, (r, f));" ] diff --git a/docs/tutorials/tutorial_es_2.ipynb b/docs/tutorials/tutorial_es_2.ipynb index ac868cc7df..dc20bfc3d8 100644 --- a/docs/tutorials/tutorial_es_2.ipynb +++ b/docs/tutorials/tutorial_es_2.ipynb @@ -10,6 +10,7 @@ "from sisl import *\n", "import sisl.viz\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline" ] }, @@ -50,7 +51,7 @@ "metadata": {}, "outputs": [], "source": [ - "H_bp = H.copy() # an exact copy\n", + "H_bp = H.copy() # an exact copy\n", "H_bp[0, 0] = 0.1\n", "H_bp[1, 1] = -0.1" ] @@ -70,7 +71,9 @@ "metadata": {}, "outputs": [], "source": [ - "band = BandStructure(H, [[0, 0.5, 0], [1/3, 2/3, 0], [0.5, 0.5, 0]], 400, [r\"$M$\", r\"$K$\", r\"$M'$\"])" + "band = BandStructure(\n", + " H, [[0, 0.5, 0], [1 / 3, 2 / 3, 0], [0.5, 0.5, 0]], 400, [r\"$M$\", r\"$K$\", r\"$M'$\"]\n", + ")" ] }, { @@ -98,8 +101,10 @@ "metadata": {}, "outputs": [], "source": [ - "bz = MonkhorstPack(H, [41, 41, 1], displacement=[1/3, 2/3, 0], size=[.125, .125, 1])\n", - "bz_average = bz.apply.average # specify the Brillouin zone to perform an average" + "bz = MonkhorstPack(\n", + " H, [41, 41, 1], displacement=[1 / 3, 2 / 3, 0], size=[0.125, 0.125, 1]\n", + ")\n", + "bz_average = bz.apply.average # specify the Brillouin zone to perform an average" ] }, { @@ -116,9 +121,9 @@ "metadata": {}, "outputs": [], "source": [ - "plt.scatter(bz.k[:, 0], bz.k[:, 1], 2);\n", - "plt.xlabel(r'$k_x$ [$b_x$]');\n", - "plt.ylabel(r'$k_y$ [$b_y$]');" + "plt.scatter(bz.k[:, 0], bz.k[:, 1], 2)\n", + "plt.xlabel(r\"$k_x$ [$b_x$]\")\n", + "plt.ylabel(r\"$k_y$ [$b_y$]\");" ] }, { @@ -135,15 +140,23 @@ "outputs": [], "source": [ "E = np.linspace(-0.5, 0.5, 1000)\n", - "dist = get_distribution('gaussian', 0.03)\n", + "dist = get_distribution(\"gaussian\", 0.03)\n", "bz.set_parent(H)\n", - "plt.plot(E, bz_average.eigenvalue(wrap=lambda ev: ev.DOS(E, distribution=dist)), label='Graphene');\n", + "plt.plot(\n", + " E,\n", + " bz_average.eigenvalue(wrap=lambda ev: ev.DOS(E, distribution=dist)),\n", + " label=\"Graphene\",\n", + ")\n", "bz.set_parent(H_bp)\n", - "plt.plot(E, bz_average.eigenvalue(wrap=lambda ev: ev.DOS(E, distribution=dist)), label='Graphene anti');\n", + "plt.plot(\n", + " E,\n", + " bz_average.eigenvalue(wrap=lambda ev: ev.DOS(E, distribution=dist)),\n", + " label=\"Graphene anti\",\n", + ")\n", "plt.legend()\n", "plt.ylim([0, None])\n", - "plt.xlabel('$E - E_F$ [eV]');\n", - "plt.ylabel('DOS [1/eV]');" + "plt.xlabel(\"$E - E_F$ [eV]\")\n", + "plt.ylabel(\"DOS [1/eV]\");" ] }, { @@ -168,12 +181,12 @@ "# Normal vector (in units of reciprocal lattice vectors)\n", "normal = [0, 0, 1]\n", "# Origo (in units of reciprocal lattice vectors)\n", - "origin = [1/3, 2/3, 0]\n", + "origin = [1 / 3, 2 / 3, 0]\n", "circle = BrillouinZone.param_circle(H, N, kR, normal, origin)\n", - "plt.plot(circle.k[:, 0], circle.k[:, 1]);\n", - "plt.xlabel(r'$k_x$ [$b_x$]')\n", - "plt.ylabel(r'$k_y$ [$b_y$]')\n", - "plt.gca().set_aspect('equal');" + "plt.plot(circle.k[:, 0], circle.k[:, 1])\n", + "plt.xlabel(r\"$k_x$ [$b_x$]\")\n", + "plt.ylabel(r\"$k_y$ [$b_y$]\")\n", + "plt.gca().set_aspect(\"equal\");" ] }, { @@ -192,10 +205,10 @@ "outputs": [], "source": [ "k = circle.tocartesian(circle.k)\n", - "plt.plot(k[:, 0], k[:, 1]);\n", - "plt.xlabel(r'$k_x$ [1/Ang]')\n", - "plt.ylabel(r'$k_y$ [1/Ang]')\n", - "plt.gca().set_aspect('equal');" + "plt.plot(k[:, 0], k[:, 1])\n", + "plt.xlabel(r\"$k_x$ [1/Ang]\")\n", + "plt.ylabel(r\"$k_y$ [1/Ang]\")\n", + "plt.gca().set_aspect(\"equal\");" ] }, { @@ -212,13 +225,21 @@ "outputs": [], "source": [ "circle.set_parent(H)\n", - "print('Pristine graphene (0): {:.5f} rad'.format(electron.berry_phase(circle, sub=0)))\n", - "print('Pristine graphene (1): {:.5f} rad'.format(electron.berry_phase(circle, sub=1)))\n", - "print('Pristine graphene (:): {:.5f} rad'.format(electron.berry_phase(circle)))\n", + "print(\"Pristine graphene (0): {:.5f} rad\".format(electron.berry_phase(circle, sub=0)))\n", + "print(\"Pristine graphene (1): {:.5f} rad\".format(electron.berry_phase(circle, sub=1)))\n", + "print(\"Pristine graphene (:): {:.5f} rad\".format(electron.berry_phase(circle)))\n", "circle.set_parent(H_bp)\n", - "print('Anti-symmetric graphene (0): {:.5f} rad'.format(electron.berry_phase(circle, sub=0)))\n", - "print('Anti-symmetric graphene (1): {:.5f} rad'.format(electron.berry_phase(circle, sub=1)))\n", - "print('Anti-symmetric graphene (:): {:.5f} rad'.format(electron.berry_phase(circle)))" + "print(\n", + " \"Anti-symmetric graphene (0): {:.5f} rad\".format(\n", + " electron.berry_phase(circle, sub=0)\n", + " )\n", + ")\n", + "print(\n", + " \"Anti-symmetric graphene (1): {:.5f} rad\".format(\n", + " electron.berry_phase(circle, sub=1)\n", + " )\n", + ")\n", + "print(\"Anti-symmetric graphene (:): {:.5f} rad\".format(electron.berry_phase(circle)))" ] }, { @@ -240,14 +261,16 @@ "for i, kR in enumerate(kRs):\n", " circle = BrillouinZone.param_circle(H_bp, dk, kR, normal, origin)\n", " bp[0, i] = electron.berry_phase(circle, sub=0)\n", - " circle_other = BrillouinZone.param_circle(utils.mathematics.fnorm(H_bp.rcell), dk, kR, normal, origin)\n", + " circle_other = BrillouinZone.param_circle(\n", + " utils.mathematics.fnorm(H_bp.rcell), dk, kR, normal, origin\n", + " )\n", " circle.k[:, :] = circle_other.k[:, :]\n", " bp[1, i] = electron.berry_phase(circle, sub=0)\n", - "plt.plot(kRs, bp[0, :]/np.pi, label=r'1/Ang');\n", - "plt.plot(kRs, bp[1, :]/np.pi, label=r'$b_i$');\n", + "plt.plot(kRs, bp[0, :] / np.pi, label=r\"1/Ang\")\n", + "plt.plot(kRs, bp[1, :] / np.pi, label=r\"$b_i$\")\n", "plt.legend()\n", - "plt.xlabel(r'Integration radius [1/Ang]');\n", - "plt.ylabel(r'Berry phase [$\\phi/\\pi$]');" + "plt.xlabel(r\"Integration radius [1/Ang]\")\n", + "plt.ylabel(r\"Berry phase [$\\phi/\\pi$]\");" ] }, { diff --git a/docs/tutorials/tutorial_siesta_1.ipynb b/docs/tutorials/tutorial_siesta_1.ipynb index bbb76a9ca2..d23bc52fdb 100644 --- a/docs/tutorials/tutorial_siesta_1.ipynb +++ b/docs/tutorials/tutorial_siesta_1.ipynb @@ -7,7 +7,8 @@ "outputs": [], "source": [ "import os\n", - "os.chdir('siesta_1')\n", + "\n", + "os.chdir(\"siesta_1\")\n", "import numpy as np\n", "from sisl import *\n", "import sisl.viz\n", @@ -15,6 +16,7 @@ "from sisl.viz.processors.math import normalize\n", "from functools import partial\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline" ] }, @@ -38,9 +40,11 @@ "metadata": {}, "outputs": [], "source": [ - "h2o = Geometry([[0, 0, 0], [0.8, 0.6, 0], [-0.8, 0.6, 0.]], \n", - " [Atom('O'), Atom('H'), Atom('H')], \n", - " lattice=Lattice(10, origin=[-5] * 3))" + "h2o = Geometry(\n", + " [[0, 0, 0], [0.8, 0.6, 0], [-0.8, 0.6, 0.0]],\n", + " [Atom(\"O\"), Atom(\"H\"), Atom(\"H\")],\n", + " lattice=Lattice(10, origin=[-5] * 3),\n", + ")" ] }, { @@ -93,7 +97,8 @@ "metadata": {}, "outputs": [], "source": [ - "open('RUN.fdf', 'w').write(\"\"\"%include STRUCT.fdf\n", + "open(\"RUN.fdf\", \"w\").write(\n", + " \"\"\"%include STRUCT.fdf\n", "SystemLabel siesta_1\n", "PAO.BasisSize SZP\n", "MeshCutoff 250. Ry\n", @@ -101,8 +106,9 @@ "CDF.Compress 9\n", "SaveHS true\n", "SaveRho true\n", - "\"\"\")\n", - "h2o.write('STRUCT.fdf')" + "\"\"\"\n", + ")\n", + "h2o.write(\"STRUCT.fdf\")" ] }, { @@ -139,7 +145,7 @@ "metadata": {}, "outputs": [], "source": [ - "fdf = get_sile('RUN.fdf')\n", + "fdf = get_sile(\"RUN.fdf\")\n", "H = fdf.read_hamiltonian()\n", "# Create a short-hand to handle the geometry\n", "h2o = H.geometry\n", @@ -171,26 +177,28 @@ "outputs": [], "source": [ "def plot_atom(atom):\n", - " no = len(atom) # number of orbitals\n", + " no = len(atom) # number of orbitals\n", " nx = no // 4\n", " ny = no // nx\n", " if nx * ny < no:\n", " nx += 1\n", - " fig, axs = plt.subplots(nx, ny, figsize=(20, 5*nx))\n", - " fig.suptitle('Atom: {}'.format(atom.symbol), fontsize=14)\n", + " fig, axs = plt.subplots(nx, ny, figsize=(20, 5 * nx))\n", + " fig.suptitle(\"Atom: {}\".format(atom.symbol), fontsize=14)\n", + "\n", " def my_plot(i, orb):\n", " grid = orb.toGrid(atom=atom)\n", " # Also write to a cube file\n", - " grid.write('{}_{}.cube'.format(atom.symbol, orb.name()))\n", + " grid.write(\"{}_{}.cube\".format(atom.symbol, orb.name()))\n", " c, r = i // 4, (i - 4) % 4\n", " if nx == 1:\n", " ax = axs[r]\n", " else:\n", " ax = axs[c][r]\n", " ax.imshow(grid.grid[:, :, grid.shape[2] // 2])\n", - " ax.set_title(r'${}$'.format(orb.name(True)))\n", - " ax.set_xlabel(r'$x$ [Ang]')\n", - " ax.set_ylabel(r'$y$ [Ang]')\n", + " ax.set_title(r\"${}$\".format(orb.name(True)))\n", + " ax.set_xlabel(r\"$x$ [Ang]\")\n", + " ax.set_ylabel(r\"$y$ [Ang]\")\n", + "\n", " i = 0\n", " for orb in atom:\n", " my_plot(i, orb)\n", @@ -205,6 +213,8 @@ " ax = axs[c][r]\n", " fig.delaxes(ax)\n", " plt.draw()\n", + "\n", + "\n", "plot_atom(h2o.atoms[0])\n", "plot_atom(h2o.atoms[1])" ] @@ -239,11 +249,12 @@ "\n", "plots = [\n", " h2o.plot(axes=\"xy\", atoms_style={\"size\": normalize(n), \"color\": c})\n", - " for n, c in zip(h2o.apply(es.norm2(sum=False),\n", - " np.sum,\n", - " mapper=partial(h2o.a2o, all=True),\n", - " axis=1),\n", - " (\"red\", \"blue\", \"green\"))\n", + " for n, c in zip(\n", + " h2o.apply(\n", + " es.norm2(sum=False), np.sum, mapper=partial(h2o.a2o, all=True), axis=1\n", + " ),\n", + " (\"red\", \"blue\", \"green\"),\n", + " )\n", "]\n", "\n", "merge_plots(*plots, composite_method=\"subplots\", cols=2)" @@ -265,15 +276,21 @@ "outputs": [], "source": [ "def integrate(g):\n", - " print('Real space integrated wavefunction: {:.4f}'.format((np.absolute(g.grid) ** 2).sum() * g.dvolume))\n", + " print(\n", + " \"Real space integrated wavefunction: {:.4f}\".format(\n", + " (np.absolute(g.grid) ** 2).sum() * g.dvolume\n", + " )\n", + " )\n", + "\n", + "\n", "g = Grid(0.2, lattice=h2o.lattice)\n", "es.sub(0).wavefunction(g)\n", "integrate(g)\n", - "#g.write('HOMO.cube')\n", - "g.fill(0) # reset the grid values to 0\n", + "# g.write('HOMO.cube')\n", + "g.fill(0) # reset the grid values to 0\n", "es.sub(1).wavefunction(g)\n", "integrate(g)\n", - "#g.write('LUMO.cube')" + "# g.write('LUMO.cube')" ] }, { @@ -295,7 +312,7 @@ "outputs": [], "source": [ "DM = fdf.read_density_matrix()\n", - "rho = get_sile('siesta_1.nc').read_grid('Rho')" + "rho = get_sile(\"siesta_1.nc\").read_grid(\"Rho\")" ] }, { @@ -313,7 +330,11 @@ "source": [ "diff = rho * (-1)\n", "DM.density(diff)\n", - "print('Real space integrated density difference: {:.3e}'.format(diff.grid.sum() * diff.dvolume))" + "print(\n", + " \"Real space integrated density difference: {:.3e}\".format(\n", + " diff.grid.sum() * diff.dvolume\n", + " )\n", + ")" ] }, { diff --git a/docs/tutorials/tutorial_siesta_2.ipynb b/docs/tutorials/tutorial_siesta_2.ipynb index 383d09d0c1..b5541cf5b6 100644 --- a/docs/tutorials/tutorial_siesta_2.ipynb +++ b/docs/tutorials/tutorial_siesta_2.ipynb @@ -7,11 +7,13 @@ "outputs": [], "source": [ "import os\n", - "os.chdir('siesta_2')\n", + "\n", + "os.chdir(\"siesta_2\")\n", "import numpy as np\n", "from sisl import *\n", "import sisl.viz\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline" ] }, @@ -68,7 +70,8 @@ "metadata": {}, "outputs": [], "source": [ - "open('RUN.fdf', 'w').write(\"\"\"%include STRUCT.fdf\n", + "open(\"RUN.fdf\", \"w\").write(\n", + " \"\"\"%include STRUCT.fdf\n", "SystemLabel siesta_2\n", "PAO.BasisSize SZP\n", "MeshCutoff 250. Ry\n", @@ -81,8 +84,9 @@ " 1 61 1 0.\n", " 0 0 1 0.\n", "%endblock\n", - "\"\"\")\n", - "graphene.write('STRUCT.fdf')" + "\"\"\"\n", + ")\n", + "graphene.write(\"STRUCT.fdf\")" ] }, { @@ -103,7 +107,7 @@ "metadata": {}, "outputs": [], "source": [ - "fdf = get_sile('RUN.fdf')\n", + "fdf = get_sile(\"RUN.fdf\")\n", "H = fdf.read_hamiltonian()\n", "print(H)" ] @@ -126,10 +130,15 @@ "E = np.linspace(-6, 4, 500)\n", "for nk in [21, 41, 61, 81]:\n", " bz = MonkhorstPack(H, [nk, nk, 1])\n", - " plt.plot(E, bz.apply.average.eigenvalue(wrap=lambda ev: ev.DOS(E)), label='nk={}'.format(nk));\n", - "plt.xlim(E[0], E[-1]); plt.ylim(0, None)\n", - "plt.xlabel(r'$E - E_F$ [eV]')\n", - "plt.ylabel(r'DOS [1/eV]')\n", + " plt.plot(\n", + " E,\n", + " bz.apply.average.eigenvalue(wrap=lambda ev: ev.DOS(E)),\n", + " label=\"nk={}\".format(nk),\n", + " )\n", + "plt.xlim(E[0], E[-1])\n", + "plt.ylim(0, None)\n", + "plt.xlabel(r\"$E - E_F$ [eV]\")\n", + "plt.ylabel(r\"DOS [1/eV]\")\n", "plt.legend();" ] }, @@ -160,7 +169,7 @@ "orb_groups = [\n", " {\"l\": 0, \"name\": \"s\", \"color\": \"red\"},\n", " {\"l\": 1, \"m\": [-1, 1], \"name\": \"px + py\", \"color\": \"blue\"},\n", - " {\"l\": 1, \"m\": 0, \"name\": \"pz\", \"color\": \"green\"}\n", + " {\"l\": 1, \"m\": 0, \"name\": \"pz\", \"color\": \"green\"},\n", "]\n", "pdos_plot.update_inputs(groups=orb_groups)" ] @@ -181,8 +190,12 @@ "outputs": [], "source": [ "# Define the band-structure\n", - "bz = BandStructure(H, [[0] * 3, [2./3, 1./3, 0], [0.5, 0.5, 0], [1] * 3], 400, \n", - " names=[r'$\\Gamma$', r'$K$', r'$M$', r'$\\Gamma$'])" + "bz = BandStructure(\n", + " H,\n", + " [[0] * 3, [2.0 / 3, 1.0 / 3, 0], [0.5, 0.5, 0], [1] * 3],\n", + " 400,\n", + " names=[r\"$\\Gamma$\", r\"$K$\", r\"$M$\", r\"$\\Gamma$\"],\n", + ")" ] }, { @@ -235,15 +248,15 @@ "# To do this we need to find the index of the corresponding z-plane.\n", "# The Grid.index method is useful in this regard.\n", "xyz = H.geometry.xyz[0, :].copy()\n", - "xyz[2] += 1.\n", + "xyz[2] += 1.0\n", "z_idx = g.index(xyz, axis=2)\n", - "x, y = np.mgrid[:g.shape[0], :g.shape[1]]\n", + "x, y = np.mgrid[: g.shape[0], : g.shape[1]]\n", "x, y = x * g.dcell[0, 0] + y * g.dcell[1, 0], x * g.dcell[0, 1] + y * g.dcell[1, 1]\n", - "plt.contourf(x, y, g.grid[:, :, z_idx]);\n", + "plt.contourf(x, y, g.grid[:, :, z_idx])\n", "xyz = H.geometry.tile(2, 0).tile(2, 1).xyz\n", - "plt.scatter(xyz[:, 0], xyz[:, 1], 20, c='k');\n", - "plt.xlabel(r'$x$ [Ang]');\n", - "plt.ylabel(r'$y$ [Ang]');" + "plt.scatter(xyz[:, 0], xyz[:, 1], 20, c=\"k\")\n", + "plt.xlabel(r\"$x$ [Ang]\")\n", + "plt.ylabel(r\"$y$ [Ang]\");" ] }, { @@ -260,20 +273,20 @@ "outputs": [], "source": [ "_, axs = plt.subplots(1, 2, figsize=(16, 5))\n", - "es = H.eigenstate([1./2, 0, 0])\n", + "es = H.eigenstate([1.0 / 2, 0, 0])\n", "idx_valence = (es.eig > 0).nonzero()[0][0] - 1\n", "es = es.sub(idx_valence)\n", "g = Grid(0.2, dtype=np.complex128, lattice=H.geometry.lattice.tile(4, 0).tile(4, 1))\n", "es.wavefunction(g)\n", - "x, y = np.mgrid[:g.shape[0], :g.shape[1]]\n", + "x, y = np.mgrid[: g.shape[0], : g.shape[1]]\n", "x, y = x * g.dcell[0, 0] + y * g.dcell[1, 0], x * g.dcell[0, 1] + y * g.dcell[1, 1]\n", - "axs[0].contourf(x, y, g.grid[:, :, z_idx].real);\n", - "axs[1].contourf(x, y, g.grid[:, :, z_idx].imag);\n", + "axs[0].contourf(x, y, g.grid[:, :, z_idx].real)\n", + "axs[1].contourf(x, y, g.grid[:, :, z_idx].imag)\n", "xyz = H.geometry.tile(4, 0).tile(4, 1).xyz\n", "for ax in axs:\n", - " ax.scatter(xyz[:, 0], xyz[:, 1], 20, c='k');\n", - " ax.set_xlabel(r'$x$ [Ang]');\n", - " ax.set_ylabel(r'$y$ [Ang]');" + " ax.scatter(xyz[:, 0], xyz[:, 1], 20, c=\"k\")\n", + " ax.set_xlabel(r\"$x$ [Ang]\")\n", + " ax.set_ylabel(r\"$y$ [Ang]\");" ] }, { diff --git a/docs/visualization/viz_module/basic-tutorials/Demo.ipynb b/docs/visualization/viz_module/basic-tutorials/Demo.ipynb index f8dcafe17d..e4551ef55f 100644 --- a/docs/visualization/viz_module/basic-tutorials/Demo.ipynb +++ b/docs/visualization/viz_module/basic-tutorials/Demo.ipynb @@ -14,8 +14,11 @@ "outputs": [], "source": [ "import sisl\n", + "\n", "# We define the root directory where our files are\n", - "siesta_files = sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"" + "siesta_files = (\n", + " sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"\n", + ")" ] }, { @@ -86,7 +89,7 @@ "outputs": [], "source": [ "rho_file = sisl.get_sile(siesta_files / \"SrTiO3.RHO\")\n", - "rho_file.plot(axes=\"xy\", nsc=[2,1,1], smooth=True)" + "rho_file.plot(axes=\"xy\", nsc=[2, 1, 1], smooth=True)" ] }, { @@ -304,7 +307,7 @@ }, "outputs": [], "source": [ - "thumbnail_plot = rho_file.plot(axes=\"xy\", nsc=[2,1,1], smooth=True)\n", + "thumbnail_plot = rho_file.plot(axes=\"xy\", nsc=[2, 1, 1], smooth=True)\n", "\n", "if thumbnail_plot:\n", " thumbnail_plot.show(\"png\")" diff --git a/docs/visualization/viz_module/combining-plots/Intro to combining plots.ipynb b/docs/visualization/viz_module/combining-plots/Intro to combining plots.ipynb index 31db0d09dd..c73cb3f92b 100644 --- a/docs/visualization/viz_module/combining-plots/Intro to combining plots.ipynb +++ b/docs/visualization/viz_module/combining-plots/Intro to combining plots.ipynb @@ -29,9 +29,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "from sisl.viz import merge_plots" @@ -56,11 +54,15 @@ "r = np.linspace(0, 3.5, 50)\n", "f = np.exp(-r)\n", "\n", - "orb = sisl.AtomicOrbital('2pzZ', (r, f))\n", - "geom = sisl.geom.graphene(orthogonal=False, atoms=[sisl.Atom(5, orb), sisl.Atom(7, orb)])\n", + "orb = sisl.AtomicOrbital(\"2pzZ\", (r, f))\n", + "geom = sisl.geom.graphene(\n", + " orthogonal=False, atoms=[sisl.Atom(5, orb), sisl.Atom(7, orb)]\n", + ")\n", "geom = geom.move([0, 0, 5])\n", "H = sisl.Hamiltonian(geom)\n", - "H.construct([(0.1, 1.44), (0, -2.7)], )\n", + "H.construct(\n", + " [(0.1, 1.44), (0, -2.7)],\n", + ")\n", "H[0, 0] = -0.7\n", "H[1, 1] = 0.7" ] @@ -81,13 +83,15 @@ "outputs": [], "source": [ "band_structure = sisl.BandStructure(\n", - " H, \n", - " [[0, 0, 0], [0, 0.5, 0],[1/3, 2/3, 0], [0, 0, 0]], \n", + " H,\n", + " [[0, 0, 0], [0, 0.5, 0], [1 / 3, 2 / 3, 0], [0, 0, 0]],\n", " 400,\n", - " [r'Gamma', r'M',r'K', r'Gamma']\n", + " [r\"Gamma\", r\"M\", r\"K\", r\"Gamma\"],\n", ")\n", "bands_plot = band_structure.plot()\n", - "pdos_plot = H.plot.pdos(data_Erange=[-10, 10], Erange=[-10,10], kgrid=[121, 121, 1], nE=1000).split_DOS(name=\"$species\")\n", + "pdos_plot = H.plot.pdos(\n", + " data_Erange=[-10, 10], Erange=[-10, 10], kgrid=[121, 121, 1], nE=1000\n", + ").split_DOS(name=\"$species\")\n", "\n", "plots = [bands_plot, pdos_plot]" ] @@ -238,7 +242,9 @@ "source": [ "merged_plot = merge_plots(*plots, composite_method=\"multiple_x\")\n", "\n", - "merge_plots(merged_plot, bands_plot, composite_method=\"subplots\", cols=2, backend=\"plotly\")" + "merge_plots(\n", + " merged_plot, bands_plot, composite_method=\"subplots\", cols=2, backend=\"plotly\"\n", + ")" ] }, { @@ -270,12 +276,18 @@ "# Do 1 by 1 from 1 to 12 and then in steps of 5 from 15 to 90.\n", "ks = [*np.arange(1, 12), *np.arange(15, 90, 5)]\n", "\n", - "# Generate all plots. \n", + "# Generate all plots.\n", "# We use the scatter trace instead of a line because it looks better in animations :)\n", "pdos_plots = [\n", " H.plot.pdos(\n", - " data_Erange=[-10, 10], Erange=[-10,10], kgrid=[k, k, 1], nE=1000, line_mode=\"scatter\", line_scale=2\n", - " ).split_DOS(name=\"$species\") for k in ks\n", + " data_Erange=[-10, 10],\n", + " Erange=[-10, 10],\n", + " kgrid=[k, k, 1],\n", + " nE=1000,\n", + " line_mode=\"scatter\",\n", + " line_scale=2,\n", + " ).split_DOS(name=\"$species\")\n", + " for k in ks\n", "]" ] }, diff --git a/docs/visualization/viz_module/diy/Adding new backends.ipynb b/docs/visualization/viz_module/diy/Adding new backends.ipynb index 6cfc05574b..e7f335950a 100644 --- a/docs/visualization/viz_module/diy/Adding new backends.ipynb +++ b/docs/visualization/viz_module/diy/Adding new backends.ipynb @@ -24,9 +24,11 @@ "\n", "geom = sisl.geom.graphene(orthogonal=True)\n", "H = sisl.Hamiltonian(geom)\n", - "H.construct([(0.1, 1.44), (0, -2.7)], )\n", + "H.construct(\n", + " [(0.1, 1.44), (0, -2.7)],\n", + ")\n", "\n", - "band_struct = sisl.BandStructure(H, [[0,0,0], [0.5,0,0]], 10, [\"Gamma\", \"X\"])" + "band_struct = sisl.BandStructure(H, [[0, 0, 0], [0.5, 0, 0]], 10, [\"Gamma\", \"X\"])" ] }, { @@ -83,9 +85,7 @@ "cell_type": "code", "execution_count": null, "id": "ae3dcf66-fd10-47bb-a706-b24b4bb2ba40", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "help(Figure)" @@ -109,20 +109,21 @@ "outputs": [], "source": [ "import numpy as np\n", + "\n", + "\n", "class TextFigure(Figure):\n", - " \n", " def _init_figure(self, *args, **kwargs):\n", " self.text = \"\"\n", - " \n", + "\n", " def clear(self):\n", " self.text = \"\"\n", - " \n", + "\n", " def draw_line(self, x, y, name, **kwargs):\n", " self.text += f\"\\nLINE: {name}\\n{np.array(x)}\\n{np.array(y)}\"\n", - " \n", + "\n", " def draw_scatter(self, x, y, name, **kwargs):\n", " self.text += f\"\\nSCATTER: {name}\\n{np.array(x)}\\n{np.array(y)}\"\n", - " \n", + "\n", " def show(self):\n", " print(self.text)\n", "\n", diff --git a/docs/visualization/viz_module/diy/Building a new plot.ipynb b/docs/visualization/viz_module/diy/Building a new plot.ipynb index b257d23cba..b925dccdfa 100644 --- a/docs/visualization/viz_module/diy/Building a new plot.ipynb +++ b/docs/visualization/viz_module/diy/Building a new plot.ipynb @@ -86,7 +86,6 @@ "outputs": [], "source": [ "def a_cool_plot(color=\"red\", backend=\"plotly\"):\n", - "\n", " action = plot_actions.draw_line(x=[1, 2], y=[3, 4], line={\"color\": color})\n", "\n", " return get_figure(backend=backend, plot_actions=[action])" @@ -219,7 +218,6 @@ "outputs": [], "source": [ "class CoolPlot(Plot):\n", - "\n", " # The function that this workflow will execute\n", " function = staticmethod(a_cool_plot)\n", "\n", diff --git a/docs/visualization/viz_module/showcase/BandsPlot.ipynb b/docs/visualization/viz_module/showcase/BandsPlot.ipynb index e7a469e051..5f3d990b89 100644 --- a/docs/visualization/viz_module/showcase/BandsPlot.ipynb +++ b/docs/visualization/viz_module/showcase/BandsPlot.ipynb @@ -29,8 +29,11 @@ "source": [ "import sisl\n", "import sisl.viz\n", + "\n", "# This is just for convenience to retreive files\n", - "siesta_files = sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"" + "siesta_files = (\n", + " sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"\n", + ")" ] }, { @@ -46,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot = sisl.get_sile( siesta_files / \"SrTiO3.bands\").plot()" + "bands_plot = sisl.get_sile(siesta_files / \"SrTiO3.bands\").plot()" ] }, { @@ -203,6 +206,7 @@ " print(data)\n", " return \"green\"\n", "\n", + "\n", "bands_plot.update_inputs(bands_style={\"color\": color})" ] }, @@ -223,27 +227,29 @@ "source": [ "def gradient(data):\n", " \"\"\"Function that computes the absolute value of dE/dk.\n", - " \n", + "\n", " This returns a two dimensional array (gradient depends on k and band)\n", " \"\"\"\n", " return abs(data.E.differentiate(\"k\"))\n", "\n", + "\n", "def band_closeness_to_Ef(data):\n", " \"\"\"Computes how close one band is to the fermi level.\n", - " \n", + "\n", " This returns a one dimensional array (distance depends only on band)\n", " \"\"\"\n", " dist_from_Ef = abs(data.E).min(\"k\")\n", - " \n", - " return (1 / dist_from_Ef ** 0.4) * 5\n", + "\n", + " return (1 / dist_from_Ef**0.4) * 5\n", + "\n", "\n", "# Now we are going to set the width of the band according to the distance from the fermi level\n", "# and the color according to the gradient. We are going to set the colorscale also, instead of using\n", "# the default one.\n", "bands_plot.update_inputs(\n", - " bands_style={\"width\": band_closeness_to_Ef, \"color\": gradient}, \n", + " bands_style={\"width\": band_closeness_to_Ef, \"color\": gradient},\n", " colorscale=\"temps\",\n", - " Erange=[-10, 10]\n", + " Erange=[-10, 10],\n", ")" ] }, @@ -280,7 +286,7 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.update_inputs(gap=True, gap_color=\"green\", Erange=[-10,10])" + "bands_plot.update_inputs(gap=True, gap_color=\"green\", Erange=[-10, 10])" ] }, { @@ -374,7 +380,7 @@ "metadata": {}, "outputs": [], "source": [ - "bands_plot.nodes['bands_data'].get().k.axis" + "bands_plot.nodes[\"bands_data\"].get().k.axis" ] }, { @@ -390,7 +396,7 @@ "metadata": {}, "outputs": [], "source": [ - "axis_info = bands_plot.nodes['bands_data'].get().k.axis\n", + "axis_info = bands_plot.nodes[\"bands_data\"].get().k.axis\n", "\n", "gap_k = None\n", "for val, label in zip(axis_info[\"tickvals\"], axis_info[\"ticktext\"]):\n", @@ -433,7 +439,10 @@ "outputs": [], "source": [ "import sisl\n", - "siesta_files = sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"" + "\n", + "siesta_files = (\n", + " sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"\n", + ")" ] }, { @@ -459,10 +468,17 @@ "metadata": {}, "outputs": [], "source": [ - "band_struct = sisl.BandStructure(H, points=[[1./2, 0., 0.], [0., 0., 0.],\n", - " [1./3, 1./3, 0.], [1./2, 0., 0.]],\n", - " divisions=301,\n", - " names=['M', r'Gamma', 'K', 'M'])" + "band_struct = sisl.BandStructure(\n", + " H,\n", + " points=[\n", + " [1.0 / 2, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0],\n", + " [1.0 / 3, 1.0 / 3, 0.0],\n", + " [1.0 / 2, 0.0, 0.0],\n", + " ],\n", + " divisions=301,\n", + " names=[\"M\", r\"Gamma\", \"K\", \"M\"],\n", + ")" ] }, { @@ -478,7 +494,7 @@ "metadata": {}, "outputs": [], "source": [ - "spin_texture_plot = band_struct.plot.bands(Erange=[-2,2])\n", + "spin_texture_plot = band_struct.plot.bands(Erange=[-2, 2])\n", "spin_texture_plot" ] }, @@ -499,12 +515,10 @@ "source": [ "from sisl.viz.data_sources import SpinMoment\n", "\n", - "spin_texture_plot.update_inputs(\n", - " bands_style={\"color\": SpinMoment(\"x\"), \"width\": 3}\n", - ")\n", + "spin_texture_plot.update_inputs(bands_style={\"color\": SpinMoment(\"x\"), \"width\": 3})\n", "\n", "# We hide the legend so that the colorbar can be easily seen.\n", - "spin_texture_plot.update_layout(showlegend=False) " + "spin_texture_plot.update_layout(showlegend=False)" ] }, { @@ -525,6 +539,7 @@ " print(data)\n", " return \"green\"\n", "\n", + "\n", "spin_texture_plot.update_inputs(bands_style={\"color\": color})" ] }, diff --git a/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb b/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb index 14811a063f..590e7469b9 100644 --- a/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb +++ b/docs/visualization/viz_module/showcase/FatbandsPlot.ipynb @@ -75,9 +75,12 @@ "metadata": {}, "outputs": [], "source": [ - "band = sisl.BandStructure(H, [[0., 0.], [2./3, 1./3],\n", - " [1./2, 1./2], [1., 1.]], 301,\n", - " [r'Gamma', 'K', 'M', r'Gamma'])" + "band = sisl.BandStructure(\n", + " H,\n", + " [[0.0, 0.0], [2.0 / 3, 1.0 / 3], [1.0 / 2, 1.0 / 2], [1.0, 1.0]],\n", + " 301,\n", + " [r\"Gamma\", \"K\", \"M\", r\"Gamma\"],\n", + ")" ] }, { @@ -135,10 +138,12 @@ "metadata": {}, "outputs": [], "source": [ - "fatbands.update_inputs(groups=[\n", - " {\"species\": \"N\", \"color\": \"blue\", \"name\": \"Nitrogen\"},\n", - " {\"species\": \"B\", \"color\": \"red\", \"name\": \"Boron\"}\n", - "])" + "fatbands.update_inputs(\n", + " groups=[\n", + " {\"species\": \"N\", \"color\": \"blue\", \"name\": \"Nitrogen\"},\n", + " {\"species\": \"B\", \"color\": \"red\", \"name\": \"Boron\"},\n", + " ]\n", + ")" ] }, { diff --git a/docs/visualization/viz_module/showcase/GeometryPlot.ipynb b/docs/visualization/viz_module/showcase/GeometryPlot.ipynb index 7c8edda812..d70bf6bb5b 100644 --- a/docs/visualization/viz_module/showcase/GeometryPlot.ipynb +++ b/docs/visualization/viz_module/showcase/GeometryPlot.ipynb @@ -112,7 +112,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=\"x\",)" + "plot.update_inputs(\n", + " axes=\"x\",\n", + ")" ] }, { @@ -236,7 +238,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=[[1,1,0], [1, -1, 0]])" + "plot.update_inputs(axes=[[1, 1, 0], [1, -1, 0]])" ] }, { @@ -252,7 +254,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=[[1,1,0], [2, -2, 0]])" + "plot.update_inputs(axes=[[1, 1, 0], [2, -2, 0]])" ] }, { @@ -268,7 +270,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=[\"x\", [1,1,0]])" + "plot.update_inputs(axes=[\"x\", [1, 1, 0]])" ] }, { @@ -389,8 +391,8 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(atoms=[1,2,3,4,5], show_atoms=True, show_cell=\"axes\")\n", - "#show_cell accepts \"box\", \"axes\" and False" + "plot.update_inputs(atoms=[1, 2, 3, 4, 5], show_atoms=True, show_cell=\"axes\")\n", + "# show_cell accepts \"box\", \"axes\" and False" ] }, { @@ -519,7 +521,7 @@ "plot.update_inputs(\n", " atoms_style=[\n", " {\"color\": \"green\", \"size\": [0.6, 0.8], \"opacity\": [1, 0.3]},\n", - " {\"atoms\": [0,1], \"color\": \"orange\"}\n", + " {\"atoms\": [0, 1], \"color\": \"orange\"},\n", " ]\n", ")" ] @@ -543,7 +545,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(atoms_style=[{\"atoms\": [0,1], \"color\": \"orange\"}])" + "plot.update_inputs(atoms_style=[{\"atoms\": [0, 1], \"color\": \"orange\"}])" ] }, { @@ -559,10 +561,12 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(atoms_style=[\n", - " {\"atoms\": {\"fx\": (None, 0.4)}, \"color\": \"orange\"},\n", - " {\"atoms\": sisl.geom.AtomOdd(), \"opacity\":0.3},\n", - "])" + "plot.update_inputs(\n", + " atoms_style=[\n", + " {\"atoms\": {\"fx\": (None, 0.4)}, \"color\": \"orange\"},\n", + " {\"atoms\": sisl.geom.AtomOdd(), \"opacity\": 0.3},\n", + " ]\n", + ")" ] }, { @@ -581,12 +585,15 @@ "outputs": [], "source": [ "# Get the Y coordinates\n", - "y = plot.geometry.xyz[:,1]\n", + "y = plot.geometry.xyz[:, 1]\n", "# And color atoms according to it\n", - "plot.update_inputs(atoms_style=[\n", - " {\"color\": y}, \n", - " {\"atoms\": sisl.geom.AtomOdd(), \"opacity\":0.3},\n", - "], atoms_colorscale=\"viridis\")" + "plot.update_inputs(\n", + " atoms_style=[\n", + " {\"color\": y},\n", + " {\"atoms\": sisl.geom.AtomOdd(), \"opacity\": 0.3},\n", + " ],\n", + " atoms_colorscale=\"viridis\",\n", + ")" ] }, { @@ -627,7 +634,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=\"yx\", bonds_style={\"color\": \"orange\", \"width\": 5, \"opacity\": 0.5}).get()" + "plot.update_inputs(\n", + " axes=\"yx\", bonds_style={\"color\": \"orange\", \"width\": 5, \"opacity\": 0.5}\n", + ").get()" ] }, { @@ -644,7 +653,10 @@ "outputs": [], "source": [ "plot.update_inputs(\n", - " bonds_style={\"color\": ['blue'] * 10 + ['orange'] * 19, \"width\": np.linspace(3, 7, 29)}\n", + " bonds_style={\n", + " \"color\": [\"blue\"] * 10 + [\"orange\"] * 19,\n", + " \"width\": np.linspace(3, 7, 29),\n", + " }\n", ")" ] }, @@ -667,6 +679,7 @@ " # We are going to color the bonds based on how far they go in the Y axis\n", " return abs(geometry[bonds[:, 0], 1] - geometry[bonds[:, 1], 1])\n", "\n", + "\n", "plot.update_inputs(bonds_style={\"color\": color_bonds, \"width\": 5})" ] }, @@ -687,7 +700,10 @@ "source": [ "from sisl.viz.data_sources import BondLength, BondDataFromMatrix, BondRandom\n", "\n", - "plot.update_inputs(axes=\"yx\", bonds_style={\"color\": BondRandom(), \"width\": BondRandom() * 10, \"opacity\": 0.5})" + "plot.update_inputs(\n", + " axes=\"yx\",\n", + " bonds_style={\"color\": BondRandom(), \"width\": BondRandom() * 10, \"opacity\": 0.5},\n", + ")" ] }, { @@ -745,7 +761,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(arrows={\"data\": [0,0,2], \"name\": \"Upwards force\"})" + "plot.update_inputs(arrows={\"data\": [0, 0, 2], \"name\": \"Upwards force\"})" ] }, { @@ -761,8 +777,10 @@ "metadata": {}, "outputs": [], "source": [ - "forces = np.linspace([0,0,2], [0,3,1], 18)\n", - "plot.update_inputs(arrows={\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4})" + "forces = np.linspace([0, 0, 2], [0, 3, 1], 18)\n", + "plot.update_inputs(\n", + " arrows={\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4}\n", + ")" ] }, { @@ -778,10 +796,12 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(arrows=[\n", - " {\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4},\n", - " {\"data\": [0,0,2], \"name\": \"Upwards force\", \"color\": \"red\"}\n", - "])" + "plot.update_inputs(\n", + " arrows=[\n", + " {\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4},\n", + " {\"data\": [0, 0, 2], \"name\": \"Upwards force\", \"color\": \"red\"},\n", + " ]\n", + ")" ] }, { @@ -797,10 +817,17 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(arrows=[\n", - " {\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4},\n", - " {\"atoms\": {\"fy\": (0, 0.5)} ,\"data\": [0,0,2], \"name\": \"Upwards force\", \"color\": \"red\"}\n", - "])" + "plot.update_inputs(\n", + " arrows=[\n", + " {\"data\": forces, \"name\": \"Force\", \"color\": \"orange\", \"width\": 4},\n", + " {\n", + " \"atoms\": {\"fy\": (0, 0.5)},\n", + " \"data\": [0, 0, 2],\n", + " \"name\": \"Upwards force\",\n", + " \"color\": \"red\",\n", + " },\n", + " ]\n", + ")" ] }, { @@ -851,7 +878,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=\"xyz\", nsc=[2,1,1])" + "plot.update_inputs(axes=\"xyz\", nsc=[2, 1, 1])" ] }, { diff --git a/docs/visualization/viz_module/showcase/GridPlot.ipynb b/docs/visualization/viz_module/showcase/GridPlot.ipynb index b9a924db5a..87c24896c0 100644 --- a/docs/visualization/viz_module/showcase/GridPlot.ipynb +++ b/docs/visualization/viz_module/showcase/GridPlot.ipynb @@ -32,16 +32,17 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "import sisl\n", "import sisl.viz\n", "import numpy as np\n", + "\n", "# This is just for convenience to retreive files\n", - "siesta_files = sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"" + "siesta_files = (\n", + " sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"\n", + ")" ] }, { @@ -326,7 +327,11 @@ "outputs": [], "source": [ "plot.update_inputs(\n", - " axes=\"xyz\", isos=[{\"frac\":frac, \"opacity\": frac/2, \"color\": \"green\"} for frac in np.linspace(0.1, 0.8, 20)],\n", + " axes=\"xyz\",\n", + " isos=[\n", + " {\"frac\": frac, \"opacity\": frac / 2, \"color\": \"green\"}\n", + " for frac in np.linspace(0.1, 0.8, 20)\n", + " ],\n", ")" ] }, @@ -407,7 +412,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(z_range=[1,3])" + "plot.update_inputs(z_range=[1, 3])" ] }, { @@ -450,7 +455,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(transforms=[\"sin\", abs], crange=None) \n", + "plot.update_inputs(transforms=[\"sin\", abs], crange=None)\n", "# If a string is provided with no module, it will be interpreted as a numpy function\n", "# Therefore \"sin\" == \"numpy.sin\" and abs != \"abs\" == \"numpy.abs\"" ] @@ -470,7 +475,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(nsc=[1,3,1])" + "plot.update_inputs(nsc=[1, 3, 1])" ] }, { diff --git a/docs/visualization/viz_module/showcase/PdosPlot.ipynb b/docs/visualization/viz_module/showcase/PdosPlot.ipynb index 13205c7cdf..e265da4d07 100644 --- a/docs/visualization/viz_module/showcase/PdosPlot.ipynb +++ b/docs/visualization/viz_module/showcase/PdosPlot.ipynb @@ -24,15 +24,16 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "import sisl\n", "import sisl.viz\n", + "\n", "# This is just for convenience to retreive files\n", - "siesta_files = sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"" + "siesta_files = (\n", + " sisl._environ.get_environ_variable(\"SISL_FILES_TESTS\") / \"sisl\" / \"io\" / \"siesta\"\n", + ")" ] }, { @@ -48,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot = sisl.get_sile(siesta_files / \"SrTiO3.PDOS\").plot(Erange=[-10,10])" + "plot = sisl.get_sile(siesta_files / \"SrTiO3.PDOS\").plot(Erange=[-10, 10])" ] }, { @@ -92,12 +93,19 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(groups=[{\"name\": \"My first PDOS (Oxygen)\", \"species\": [\"O\"], \"n\": 2, \"l\": 1}])\n", + "plot.update_inputs(\n", + " groups=[{\"name\": \"My first PDOS (Oxygen)\", \"species\": [\"O\"], \"n\": 2, \"l\": 1}]\n", + ")\n", "# or (it's equivalent)\n", - "plot.update_inputs(groups=[{\n", - " \"name\": \"My first PDOS (Oxygen)\", \"species\": [\"O\"],\n", - " \"orbitals\": [\"2pzZ1\", \"2pzZ2\", \"2pxZ1\", \"2pxZ2\", \"2pyZ1\", \"2pyZ2\"]\n", - "}])" + "plot.update_inputs(\n", + " groups=[\n", + " {\n", + " \"name\": \"My first PDOS (Oxygen)\",\n", + " \"species\": [\"O\"],\n", + " \"orbitals\": [\"2pzZ1\", \"2pzZ2\", \"2pxZ1\", \"2pxZ2\", \"2pyZ1\", \"2pyZ2\"],\n", + " }\n", + " ]\n", + ")" ] }, { @@ -113,11 +121,26 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(groups=[\n", - " {\"name\": \"Oxygen\", \"species\": [\"O\"], \"color\": \"darkred\", \"dash\": \"dash\", \"reduce\": \"mean\"},\n", - " {\"name\": \"Titanium\", \"species\": [\"Ti\"], \"color\": \"gray\", \"size\": 3, \"reduce\": \"mean\"},\n", - " {\"name\": \"Sr\", \"species\": [\"Sr\"], \"color\": \"green\", \"reduce\": \"mean\"},\n", - "], Erange=[-5, 5])" + "plot.update_inputs(\n", + " groups=[\n", + " {\n", + " \"name\": \"Oxygen\",\n", + " \"species\": [\"O\"],\n", + " \"color\": \"darkred\",\n", + " \"dash\": \"dash\",\n", + " \"reduce\": \"mean\",\n", + " },\n", + " {\n", + " \"name\": \"Titanium\",\n", + " \"species\": [\"Ti\"],\n", + " \"color\": \"gray\",\n", + " \"size\": 3,\n", + " \"reduce\": \"mean\",\n", + " },\n", + " {\"name\": \"Sr\", \"species\": [\"Sr\"], \"color\": \"green\", \"reduce\": \"mean\"},\n", + " ],\n", + " Erange=[-5, 5],\n", + ")" ] }, { @@ -138,11 +161,13 @@ "# Let's import the AtomZ and AtomOdd categories just to play with them\n", "from sisl.geom import AtomZ, AtomOdd\n", "\n", - "plot.update_inputs(groups=[\n", - " {\"atoms\": [0,1], \"name\": \"Atoms 0 and 1\"},\n", - " {\"atoms\": {\"Z\": 8}, \"name\": \"Atoms with Z=8\"},\n", - " {\"atoms\": AtomZ(8) & ~ AtomOdd(), \"name\": \"Oxygens with even indices\"}\n", - "])" + "plot.update_inputs(\n", + " groups=[\n", + " {\"atoms\": [0, 1], \"name\": \"Atoms 0 and 1\"},\n", + " {\"atoms\": {\"Z\": 8}, \"name\": \"Atoms with Z=8\"},\n", + " {\"atoms\": AtomZ(8) & ~AtomOdd(), \"name\": \"Oxygens with even indices\"},\n", + " ]\n", + ")" ] }, { @@ -216,7 +241,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.split_DOS(on=\"atoms\", exclude=[1,3])" + "plot.split_DOS(on=\"atoms\", exclude=[1, 3])" ] }, { @@ -232,7 +257,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.split_DOS(on=\"atoms\", only=[0,2])" + "plot.split_DOS(on=\"atoms\", only=[0, 2])" ] }, { diff --git a/docs/visualization/viz_module/showcase/SitesPlot.ipynb b/docs/visualization/viz_module/showcase/SitesPlot.ipynb index 4ca58075eb..e0469231be 100644 --- a/docs/visualization/viz_module/showcase/SitesPlot.ipynb +++ b/docs/visualization/viz_module/showcase/SitesPlot.ipynb @@ -53,12 +53,8 @@ "\n", "# Create the circle\n", "bz = sisl.BrillouinZone.param_circle(\n", - " g,\n", - " kR=0.0085,\n", - " origin= [0.0, 0.0, 0.0],\n", - " normal= [0.0, 0.0, 1.0],\n", - " N_or_dk=25,\n", - " loop=True)" + " g, kR=0.0085, origin=[0.0, 0.0, 0.0], normal=[0.0, 0.0, 1.0], N_or_dk=25, loop=True\n", + ")" ] }, { @@ -76,7 +72,7 @@ "source": [ "data = np.zeros((len(bz), 3))\n", "\n", - "data[:, 0] = - bz.k[:, 1]\n", + "data[:, 0] = -bz.k[:, 1]\n", "data[:, 1] = bz.k[:, 0]" ] }, @@ -95,8 +91,10 @@ "source": [ "# Plot k points as sites\n", "bz.plot.sites(\n", - " axes=\"xy\", drawing_mode=\"line\", sites_style={\"color\": \"black\", \"size\": 2},\n", - " arrows={\"data\": data, \"color\": \"red\", \"width\": 3, \"name\": \"Force\"}\n", + " axes=\"xy\",\n", + " drawing_mode=\"line\",\n", + " sites_style={\"color\": \"black\", \"size\": 2},\n", + " arrows={\"data\": data, \"color\": \"red\", \"width\": 3, \"name\": \"Force\"},\n", ")" ] }, diff --git a/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb b/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb index 37190f3623..9b1018c138 100644 --- a/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb +++ b/docs/visualization/viz_module/showcase/WavefunctionPlot.ipynb @@ -61,11 +61,13 @@ "r = np.linspace(0, 3.5, 50)\n", "f = np.exp(-r)\n", "\n", - "orb = sisl.AtomicOrbital('2pzZ', (r, f))\n", + "orb = sisl.AtomicOrbital(\"2pzZ\", (r, f))\n", "geom = sisl.geom.graphene(orthogonal=True, atoms=sisl.Atom(6, orb))\n", "geom = geom.move([0, 0, 5])\n", "H = sisl.Hamiltonian(geom)\n", - "H.construct([(0.1, 1.44), (0, -2.7)], )" + "H.construct(\n", + " [(0.1, 1.44), (0, -2.7)],\n", + ")" ] }, { @@ -142,7 +144,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=\"xy\", transforms=[\"square\"]) # by default grid_prec is 0.2 Ang" + "plot.update_inputs(axes=\"xy\", transforms=[\"square\"]) # by default grid_prec is 0.2 Ang" ] }, { @@ -198,10 +200,14 @@ "metadata": {}, "outputs": [], "source": [ - "plot.update_inputs(axes=\"xyz\", nsc=[2,2,1], grid_prec=0.1, transforms=[],\n", + "plot.update_inputs(\n", + " axes=\"xyz\",\n", + " nsc=[2, 2, 1],\n", + " grid_prec=0.1,\n", + " transforms=[],\n", " isos=[\n", " {\"val\": -0.07, \"opacity\": 1, \"color\": \"salmon\"},\n", - " {\"val\": 0.07, \"opacity\": 0.7, \"color\": \"blue\"}\n", + " {\"val\": 0.07, \"opacity\": 0.7, \"color\": \"blue\"},\n", " ],\n", " geom_kwargs={\"atoms_style\": dict(color=[\"orange\", \"red\", \"green\", \"pink\"])},\n", ")" diff --git a/docs/visualization/viz_module/showcase/_template/create.py b/docs/visualization/viz_module/showcase/_template/create.py index 0fd0ea5a8f..6546deee89 100644 --- a/docs/visualization/viz_module/showcase/_template/create.py +++ b/docs/visualization/viz_module/showcase/_template/create.py @@ -14,7 +14,7 @@ def create_showcase_nb(cls, force=False): Parameters ----------- cls: str - the name of the class that you want to + the name of the class that you want to """ if cls not in [c.__name__ for c in get_plot_classes()]: message = f"We didn't find a plot class with the name '{cls}'" @@ -30,8 +30,8 @@ def create_showcase_nb(cls, force=False): with open(Path(__file__).parent.parent / f"{cls}.ipynb", "w") as f: f.write(lines.replace("<$plotclass$>", cls)) -if __name__ == "__main__": +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-c", "--cls", type=str, required=True) diff --git a/docs/visualization/viz_module/tests/test_tutorials.py b/docs/visualization/viz_module/tests/test_tutorials.py index dfaf715278..8fa7a8b734 100644 --- a/docs/visualization/viz_module/tests/test_tutorials.py +++ b/docs/visualization/viz_module/tests/test_tutorials.py @@ -11,33 +11,43 @@ def _notebook_run(path): """Execute a notebook via nbconvert and collect output. - :returns (parsed nb object, execution errors) + :returns (parsed nb object, execution errors) """ dirname, __ = os.path.split(path) os.chdir(dirname) with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout: - args = ["jupyter", "nbconvert", "--to", "notebook", "--execute", - "--ExecutePreprocessor.timeout=60", - "--output", fout.name, path] + args = [ + "jupyter", + "nbconvert", + "--to", + "notebook", + "--execute", + "--ExecutePreprocessor.timeout=60", + "--output", + fout.name, + path, + ] subprocess.check_call(args) fout.seek(0) nb = nbformat.read(fout, nbformat.current_nbformat) - errors = [output for cell in nb.cells if "outputs" in cell - for output in cell["outputs"]\ - if output.output_type == "error"] + errors = [ + output + for cell in nb.cells + if "outputs" in cell + for output in cell["outputs"] + if output.output_type == "error" + ] return nb, errors class NotebookTester: - - path = '' + path = "" generated = [] def test_ipynb(self): - # Check that the notebook has ran without errors nb, errors = _notebook_run(self.path) assert errors == [] @@ -50,25 +60,24 @@ def test_ipynb(self): os.remove(path) -tut_root = Path(__file__)/ "basic-tutorials" +tut_root = Path(__file__) / "basic-tutorials" class TestDemo(NotebookTester): - path = os.path.join(tut_root, "Demo.ipynb") - generated = [os.path.join(tut_root, file_name) for file_name in ("From_animated.plot", "From_animation.plot")] + generated = [ + os.path.join(tut_root, file_name) + for file_name in ("From_animated.plot", "From_animation.plot") + ] def test_ipynb(self): - super().test_ipynb() class TestDIY(NotebookTester): - path = os.path.join(tut_root, "DIY.ipynb") class TestGUISession(NotebookTester): - path = os.path.join(tut_root, "GUI with Python Demo.ipynb") diff --git a/examples/ex_01.py b/examples/ex_01.py index 738f0b4ba1..ac5f5ab4c6 100644 --- a/examples/ex_01.py +++ b/examples/ex_01.py @@ -12,7 +12,7 @@ bond = 1.42 # Construct the atom with the appropriate orbital range # Note the 0.01 which is for numerical accuracy. -C = sisl.Atom(6, R = bond + 0.01) +C = sisl.Atom(6, R=bond + 0.01) # Create graphene unit-cell gr = sisl.geom.graphene(bond, C) @@ -23,9 +23,9 @@ for ia in gr: idx_a = gr.close(ia, R) # On-site - H[ia, idx_a[0]] = 0. + H[ia, idx_a[0]] = 0.0 # Nearest neighbour hopping H[ia, idx_a[1]] = -2.7 # Calculate eigenvalues at K-point -print(H.eigh([2./3, 1./3, 0.])) +print(H.eigh([2.0 / 3, 1.0 / 3, 0.0])) diff --git a/examples/ex_02.py b/examples/ex_02.py index df1a268572..b577068274 100644 --- a/examples/ex_02.py +++ b/examples/ex_02.py @@ -13,7 +13,7 @@ bond = 1.42 # Construct the atom with the appropriate orbital range # Note the 0.01 which is for numerical accuracy. -C = sisl.Atom(6, R = bond + 0.01) +C = sisl.Atom(6, R=bond + 0.01) # Create graphene unit-cell gr = sisl.geom.graphene(bond, C) @@ -22,8 +22,8 @@ # Create function to be passed to the construct method. # This method is *much* faster for large scale simulations. -func = H.create_construct([0.1 * bond, bond + 0.01], [0., -2.7]) +func = H.create_construct([0.1 * bond, bond + 0.01], [0.0, -2.7]) H.construct(func) # Calculate eigenvalues at K-point -print(H.eigh([2./3, 1./3, 0.])) +print(H.eigh([2.0 / 3, 1.0 / 3, 0.0])) diff --git a/examples/ex_03.py b/examples/ex_03.py index 58c5006b49..cae245721c 100644 --- a/examples/ex_03.py +++ b/examples/ex_03.py @@ -11,8 +11,9 @@ import sisl -with open('zz.gin', 'w') as f: - f.write("""opti conv dist full nosymmetry phon dynamical_matrix nod3 +with open("zz.gin", "w") as f: + f.write( + """opti conv dist full nosymmetry phon dynamical_matrix nod3 output she cutd 3.0 @@ -149,11 +150,13 @@ C core 13.49000 18.44634 0.00000 0 1 0 1 1 1 C core 14.91000 18.44634 0.00000 0 1 0 1 1 1 -brenner""") +brenner""" + ) # Create PHtrans input -with open('ZZ.fdf', 'w') as f: - f.write("""SystemLabel ZZ +with open("ZZ.fdf", "w") as f: + f.write( + """SystemLabel ZZ TBT.DOS.Gf T @@ -178,22 +181,23 @@ semi-inf-direction +a2 electrode-position end -1 %endblock -""") +""" + ) import os -if not os.path.exists('zz.gout'): +if not os.path.exists("zz.gout"): raise ValueError("zz.gin has not been runned by GULP") -print('Reading output') -gout = sisl.get_sile('zz.gout') +print("Reading output") +gout = sisl.get_sile("zz.gout") # Correct what to read from the gulp output gout.set_lattice_key("Cartesian lattice vectors") # Selectively decide whether you want to read the dynamical # matrix from the GULP output file or from the # FORCE_CONSTANTS_2ND file. -order = ['got'] # GULP output file -#order = ['FC'] # FORCE_CONSTANTS_2ND file +order = ["got"] # GULP output file +# order = ['FC'] # FORCE_CONSTANTS_2ND file dyn = gout.read_dynamical_matrix(order=order) @@ -203,7 +207,7 @@ dyn.apply_newton() dev = dyn.untile(4, 0) -dev.write('DEVICE_zz.nc') +dev.write("DEVICE_zz.nc") el = dev.untile(4, 1) -el.write('ELEC_zz.nc') +el.write("ELEC_zz.nc") diff --git a/src/sisl/__init__.py b/src/sisl/__init__.py index c9a3d02cd1..b3c4a57de2 100644 --- a/src/sisl/__init__.py +++ b/src/sisl/__init__.py @@ -44,6 +44,7 @@ """ import logging import datetime + year = datetime.datetime.now().year # instantiate the logger, but we will not use it here... @@ -53,6 +54,7 @@ __license__ = "MPL-2.0" import sisl._version as _version + __version__ = _version.version __version_tuple__ = _version.version_tuple __bibtex__ = f"""# BibTeX information if people wish to cite @@ -96,6 +98,7 @@ # Import numerical constants (they required unit) import sisl.constant as constant + # To make it easier to type ;) C = constant @@ -136,9 +139,16 @@ # that sisl is made of. import sisl.io as io from .io.sile import ( - add_sile, get_sile_class, get_sile, - get_siles, get_sile_rules, SileError, - BaseSile, Sile, SileCDF, SileBin + add_sile, + get_sile_class, + get_sile, + get_siles, + get_sile_rules, + SileError, + BaseSile, + Sile, + SileCDF, + SileBin, ) # Allow geometry to register siles @@ -165,4 +175,4 @@ from . import viz # Make these things publicly available -__all__ = [s for s in dir() if not s.startswith('_')] +__all__ = [s for s in dir() if not s.startswith("_")] diff --git a/src/sisl/_array.py b/src/sisl/_array.py index f228ac69cf..d6b7dd9024 100644 --- a/src/sisl/_array.py +++ b/src/sisl/_array.py @@ -21,12 +21,12 @@ __all__ = ["broadcast_shapes"] -def _append(name, suffix='ilfd'): +def _append(name, suffix="ilfd"): return [name + s for s in suffix] def broadcast_shapes(*shapes): - """ Calculate the broad-casted shape of a list of shapes + """Calculate the broad-casted shape of a list of shapes This should be replaced by np.broadcast_shapes when 1.20 is the default. """ @@ -36,7 +36,7 @@ def broadcast_shapes(*shapes): def array_arange(start, end=None, n=None, dtype=int64): - """ Creates a single array from a sequence of `numpy.arange` + """Creates a single array from a sequence of `numpy.arange` Parameters ---------- @@ -94,12 +94,12 @@ def array_arange(start, end=None, n=None, dtype=int64): return cumsum(a, dtype=dtype) -__all__ += ['array_arange'] +__all__ += ["array_arange"] # Create all partial objects for creating arrays array_arangei = _partial(array_arange, dtype=int32) array_arangel = _partial(array_arange, dtype=int64) -__all__ += _append('array_arange', 'il') +__all__ += _append("array_arange", "il") zeros = np.zeros zerosi = _partial(zeros, dtype=int32) @@ -108,7 +108,7 @@ def array_arange(start, end=None, n=None, dtype=int64): zerosd = _partial(zeros, dtype=float64) zerosc = _partial(zeros, dtype=complex64) zerosz = _partial(zeros, dtype=complex128) -__all__ += _append('zeros', 'ilfdcz') +__all__ += _append("zeros", "ilfdcz") ones = np.ones onesi = _partial(ones, dtype=int32) @@ -117,7 +117,7 @@ def array_arange(start, end=None, n=None, dtype=int64): onesd = _partial(ones, dtype=float64) onesc = _partial(ones, dtype=complex64) onesz = _partial(ones, dtype=complex128) -__all__ += _append('ones', 'ilfdcz') +__all__ += _append("ones", "ilfdcz") empty = np.empty emptyi = _partial(empty, dtype=int32) @@ -126,7 +126,7 @@ def array_arange(start, end=None, n=None, dtype=int64): emptyd = _partial(empty, dtype=float64) emptyc = _partial(empty, dtype=complex64) emptyz = _partial(empty, dtype=complex128) -__all__ += _append('empty', 'ilfdcz') +__all__ += _append("empty", "ilfdcz") array = np.array arrayi = _partial(array, dtype=int32) @@ -135,7 +135,7 @@ def array_arange(start, end=None, n=None, dtype=int64): arrayd = _partial(array, dtype=float64) arrayc = _partial(array, dtype=complex64) arrayz = _partial(array, dtype=complex128) -__all__ += _append('array', 'ilfdcz') +__all__ += _append("array", "ilfdcz") asarray = np.asarray asarrayi = _partial(asarray, dtype=int32) @@ -144,7 +144,7 @@ def array_arange(start, end=None, n=None, dtype=int64): asarrayd = _partial(asarray, dtype=float64) asarrayc = _partial(asarray, dtype=complex64) asarrayz = _partial(asarray, dtype=complex128) -__all__ += _append('asarray', 'ilfdcz') + ['asarray'] +__all__ += _append("asarray", "ilfdcz") + ["asarray"] fromiter = np.fromiter fromiteri = _partial(fromiter, dtype=int32) @@ -153,7 +153,7 @@ def array_arange(start, end=None, n=None, dtype=int64): fromiterd = _partial(fromiter, dtype=float64) fromiterc = _partial(fromiter, dtype=complex64) fromiterz = _partial(fromiter, dtype=complex128) -__all__ += _append('fromiter', 'ilfdcz') +__all__ += _append("fromiter", "ilfdcz") sumi = _partial(np.sum, dtype=int32) suml = _partial(np.sum, dtype=int64) @@ -161,7 +161,7 @@ def array_arange(start, end=None, n=None, dtype=int64): sumd = _partial(np.sum, dtype=float64) sumc = _partial(np.sum, dtype=complex64) sumz = _partial(np.sum, dtype=complex128) -__all__ += _append('sum', 'ilfdcz') +__all__ += _append("sum", "ilfdcz") cumsum = np.cumsum cumsumi = _partial(cumsum, dtype=int32) @@ -170,7 +170,7 @@ def array_arange(start, end=None, n=None, dtype=int64): cumsumd = _partial(cumsum, dtype=float64) cumsumc = _partial(cumsum, dtype=complex64) cumsumz = _partial(cumsum, dtype=complex128) -__all__ += _append('cumsum', 'ilfdcz') +__all__ += _append("cumsum", "ilfdcz") arange = np.arange arangei = _partial(arange, dtype=int32) @@ -179,7 +179,7 @@ def array_arange(start, end=None, n=None, dtype=int64): aranged = _partial(arange, dtype=float64) arangec = _partial(arange, dtype=complex64) arangez = _partial(arange, dtype=complex128) -__all__ += _append('arange', 'ilfdcz') +__all__ += _append("arange", "ilfdcz") prod = np.prod prodi = _partial(prod, dtype=int32) @@ -188,7 +188,7 @@ def array_arange(start, end=None, n=None, dtype=int64): prodd = _partial(prod, dtype=float64) prodc = _partial(prod, dtype=complex64) prodz = _partial(prod, dtype=complex128) -__all__ += _append('prod', 'ilfdcz') +__all__ += _append("prod", "ilfdcz") # Create all partial objects for creating arrays full = np.full @@ -198,13 +198,13 @@ def array_arange(start, end=None, n=None, dtype=int64): fulld = _partial(full, dtype=float64) fullc = _partial(full, dtype=complex64) fullz = _partial(full, dtype=complex128) -__all__ += _append('full', 'ilfdcz') +__all__ += _append("full", "ilfdcz") linspace = np.linspace linspacef = _partial(linspace, dtype=float32) linspaced = _partial(linspace, dtype=float64) linspacec = _partial(linspace, dtype=complex64) linspacez = _partial(linspace, dtype=complex128) -__all__ += _append('linspace', 'fdcz') +__all__ += _append("linspace", "fdcz") del _append, _partial diff --git a/src/sisl/_category.py b/src/sisl/_category.py index 2488c49a18..59c5fde79b 100644 --- a/src/sisl/_category.py +++ b/src/sisl/_category.py @@ -19,7 +19,7 @@ class InstanceCache: - """ Wraps an instance to cache *all* results based on `functools.lru_cache` + """Wraps an instance to cache *all* results based on `functools.lru_cache` Parameters ---------- @@ -102,7 +102,7 @@ def __call__(cls, *args, **kwargs): @set_module("sisl.category") class Category(metaclass=CategoryMeta): - r""" A category """ + r"""A category""" __slots__ = ("_name", "_wrapper") def __init__(self, name=None): @@ -113,17 +113,17 @@ def __init__(self, name=None): @property def name(self): - r""" Name of category """ + r"""Name of category""" return self._name def set_name(self, name): - r""" Override the name of the categorization """ + r"""Override the name of the categorization""" self._name = name @classmethod @abstractmethod def is_class(cls, name, case=True): - r""" Query whether `name` matches the class name by removing a prefix `kw` + r"""Query whether `name` matches the class name by removing a prefix `kw` This is important to ensure that users match the full class name by omitting the prefix returned from this method. @@ -150,7 +150,7 @@ def is_class(cls, name): @classmethod def kw(cls, **kwargs): - """ Create categories based on keywords + """Create categories based on keywords This will search through the inherited classes and return and & category object for all keywords. @@ -184,21 +184,27 @@ def get_cat(cl, args): for cl in subcls: if cl.is_class(key): if found: - raise ValueError(f"{cls.__name__}.kw got a non-unique argument for category name:\n" - f" Searching for {key} and found matches {found.__name__} and {cl.__name__}.") + raise ValueError( + f"{cls.__name__}.kw got a non-unique argument for category name:\n" + f" Searching for {key} and found matches {found.__name__} and {cl.__name__}." + ) found = cl if found is None: for cl in subcls: if cl.is_class(key, case=False): if found: - raise ValueError(f"{cls.__name__}.kw got a non-unique argument for category name:\n" - f" Searching for {key} and found matches {found.__name__.lower()} and {cl.__name__.lower()}.") + raise ValueError( + f"{cls.__name__}.kw got a non-unique argument for category name:\n" + f" Searching for {key} and found matches {found.__name__.lower()} and {cl.__name__.lower()}." + ) found = cl if found is None: - raise ValueError(f"{cls.__name__}.kw got an argument for category name:\n" - f" Searching for {key} but found no matches.") + raise ValueError( + f"{cls.__name__}.kw got an argument for category name:\n" + f" Searching for {key} but found no matches." + ) if cat is None: cat = get_cat(found, args) @@ -209,20 +215,20 @@ def get_cat(cl, args): @abstractmethod def categorize(self, *args, **kwargs): - r""" Do categorization """ + r"""Do categorization""" pass def __str__(self): - r""" String representation of the class (non-distinguishable between equivalent classifiers) """ + r"""String representation of the class (non-distinguishable between equivalent classifiers)""" return self.name def __repr__(self): - r""" String representation of the class (non-distinguishable between equivalent classifiers) """ + r"""String representation of the class (non-distinguishable between equivalent classifiers)""" return self.name @singledispatchmethod def __eq__(self, other): - """ Comparison of two categories, they are compared by class-type """ + """Comparison of two categories, they are compared by class-type""" # This is not totally safe since composites *could* be generated # in different sequences and result in the same boolean expression. # This we do not check and thus are not fool proof... @@ -233,9 +239,12 @@ def __eq__(self, other): # (A ^ B ^ C) != (C ^ A ^ B) if isinstance(self, CompositeCategory): if isinstance(other, CompositeCategory): - return (self.__class__ is other.__class__ and - (self.A == other.A and self.B == other.B or - self.A == other.B and self.B == other.A)) + return self.__class__ is other.__class__ and ( + self.A == other.A + and self.B == other.B + or self.A == other.B + and self.B == other.A + ) # if neither is a compositecategory, then they cannot # be the same category return False @@ -277,6 +286,7 @@ class GenericCategory(Category): and `CompositeCategory` and distinguish them from categories that have a specific object in which they act. """ + @classmethod def is_class(cls, name): # never allow one to match a generic class @@ -286,7 +296,7 @@ def is_class(cls, name): @set_module("sisl.category") class NullCategory(GenericCategory): - r""" Special Null class which always represents a classification not being *anything* """ + r"""Special Null class which always represents a classification not being *anything*""" __slots__ = tuple() def __init__(self): @@ -312,7 +322,8 @@ def name(self): @set_module("sisl.category") class NotCategory(GenericCategory): - """ A class returning the *opposite* of this class (NullCategory) if it is categorized as such """ + """A class returning the *opposite* of this class (NullCategory) if it is categorized as such""" + __slots__ = ("_cat",) def __init__(self, cat): @@ -324,10 +335,11 @@ def __init__(self, cat): self._cat = cat def categorize(self, *args, **kwargs): - r""" Base method for queriyng whether an object is a certain category """ + r"""Base method for queriyng whether an object is a certain category""" cat = self._cat.categorize(*args, **kwargs) _null = NullCategory() + def check(cat): if isinstance(cat, NullCategory): return self @@ -371,7 +383,7 @@ def name(self): @set_module("sisl.category") class CompositeCategory(GenericCategory): - """ A composite class consisting of two categories, an abstract class to always be inherited + """A composite class consisting of two categories, an abstract class to always be inherited This should take 2 categories as arguments @@ -382,6 +394,7 @@ class CompositeCategory(GenericCategory): B : Category the right hand side of the set operation """ + __slots__ = ("A", "B") def __init__(self, A, B): @@ -391,24 +404,20 @@ def __init__(self, A, B): self.A = A self.B = B - def __init_subclass__(cls, /, - composite_name: str, - **kwargs): + def __init_subclass__(cls, /, composite_name: str, **kwargs): super().__init_subclass__(**kwargs) cls.name = _composite_name(composite_name) def categorizeAB(self, *args, **kwargs): - r""" Base method for queriyng whether an object is a certain category """ + r"""Base method for queriyng whether an object is a certain category""" catA = self.A.categorize(*args, **kwargs) catB = self.B.categorize(*args, **kwargs) return catA, catB - - @set_module("sisl.category") class OrCategory(CompositeCategory, composite_name="|"): - """ A class consisting of two categories + """A class consisting of two categories This should take 2 categories as arguments and a binary operator to define how the categories are related. @@ -420,10 +429,11 @@ class OrCategory(CompositeCategory, composite_name="|"): B : Category the right hand side of the set operation """ + __slots__ = tuple() def categorize(self, *args, **kwargs): - r""" Base method for queriyng whether an object is a certain category """ + r"""Base method for queriyng whether an object is a certain category""" catA, catB = self.categorizeAB(*args, **kwargs) def cmp(a, b): @@ -438,7 +448,7 @@ def cmp(a, b): @set_module("sisl.category") class AndCategory(CompositeCategory, composite_name="&"): - """ A class consisting of two categories + """A class consisting of two categories This should take 2 categories as arguments and a binary operator to define how the categories are related. @@ -450,10 +460,11 @@ class AndCategory(CompositeCategory, composite_name="&"): B : Category the right hand side of the set operation """ + __slots__ = tuple() def categorize(self, *args, **kwargs): - r""" Base method for queriyng whether an object is a certain category """ + r"""Base method for queriyng whether an object is a certain category""" catA, catB = self.categorizeAB(*args, **kwargs) def cmp(a, b): @@ -470,7 +481,7 @@ def cmp(a, b): @set_module("sisl.category") class XOrCategory(CompositeCategory, composite_name="⊕"): - """ A class consisting of two categories + """A class consisting of two categories This should take 2 categories as arguments and a binary operator to define how the categories are related. @@ -482,10 +493,11 @@ class XOrCategory(CompositeCategory, composite_name="⊕"): B : Category the right hand side of the set operation """ + __slots__ = tuple() def categorize(self, *args, **kwargs): - r""" Base method for queriyng whether an object is a certain category """ + r"""Base method for queriyng whether an object is a certain category""" catA, catB = self.categorizeAB(*args, **kwargs) def cmp(a, b): diff --git a/src/sisl/_common.py b/src/sisl/_common.py index 0abc630ebf..49e19ee4ba 100644 --- a/src/sisl/_common.py +++ b/src/sisl/_common.py @@ -9,10 +9,11 @@ @unique class Opt(Flag): - """ Global option arguments used throughout sisl + """Global option arguments used throughout sisl These flags may be combined via bit-wise operations """ + NONE = auto() ANY = auto() ALL = auto() diff --git a/src/sisl/_dispatch_class.py b/src/sisl/_dispatch_class.py index c2190db7aa..343ee682c2 100644 --- a/src/sisl/_dispatch_class.py +++ b/src/sisl/_dispatch_class.py @@ -32,13 +32,16 @@ class A(_Dispatchs, class _Dispatchs: """Subclassable for creating the new/to arguments""" - def __init_subclass__(cls, /, - dispatchs: Optional[Union[str, Sequence[Any]]]=None, - when_subclassing: Optional[str] = None, - **kwargs): + def __init_subclass__( + cls, + /, + dispatchs: Optional[Union[str, Sequence[Any]]] = None, + when_subclassing: Optional[str] = None, + **kwargs, + ): # complete the init_subclass super().__init_subclass__(**kwargs) - + # Get the allowed actions for subclassing prefix = "_tonew" allowed_subclassing = ("keep", "new", "copy") @@ -60,7 +63,6 @@ def find_base(cls, attr): loop = [] for attr in dispatchs: - # argument could be: # dispatchs = [ # ("new", "keep"), @@ -95,7 +97,6 @@ def find_base(cls, attr): obj = getattr(base, attr).copy() loop.append((attr, obj, when_subcls)) - if when_subclassing is None: # first non-None value when_subclassing = "copy" @@ -105,10 +106,11 @@ def find_base(cls, attr): _log.debug(f"{cls!r} when_subclassing = {when_subclassing}") if when_subclassing not in allowed_subclassing: - raise ValueError(f"when_subclassing should be one of {allowed_subclassing}, got {when_subclassing}") + raise ValueError( + f"when_subclassing should be one of {allowed_subclassing}, got {when_subclassing}" + ) for attr, obj, _ in loop: - if obj is None: _log.debug(f"Doing nothing for {attr} on class {cls!r}") else: diff --git a/src/sisl/_dispatcher.py b/src/sisl/_dispatcher.py index 14c97346d7..9eccc1bab8 100644 --- a/src/sisl/_dispatcher.py +++ b/src/sisl/_dispatcher.py @@ -14,19 +14,26 @@ from functools import update_wrapper __all__ = [ - "AbstractDispatch", "ObjectDispatcher", "MethodDispatcher", - "ErrorDispatcher", "ClassDispatcher", "TypeDispatcher" + "AbstractDispatch", + "ObjectDispatcher", + "MethodDispatcher", + "ErrorDispatcher", + "ClassDispatcher", + "TypeDispatcher", ] _log = logging.getLogger("sisl") _log.info(f"adding logger: {__name__}") _log = logging.getLogger(__name__) + def _dict_to_str(name, d, parser=None): - """ Convert a dict to __str__ representation """ + """Convert a dict to __str__ representation""" if parser is None: + def parser(kv): return f" {kv[0]}: {kv[1]}" + d_str = ",\n ".join(map(parser, d.items())) if len(d_str) > 0: return f"{name} ({len(d)}): [\n {d_str}\n ]" @@ -34,7 +41,7 @@ def parser(kv): class AbstractDispatch(metaclass=ABCMeta): - r""" Dispatcher class used for dispatching function calls """ + r"""Dispatcher class used for dispatching function calls""" def __init__(self, obj, **attrs): self._obj = obj @@ -44,12 +51,12 @@ def __init__(self, obj, **attrs): _log.info(f"__init__ {self.__class__.__name__}", extra={"obj": self}) def copy(self): - """ Create a copy of this object (will not copy `obj`) """ + """Create a copy of this object (will not copy `obj`)""" _log.debug(f"copy {self.__class__.__name__}", extra={"obj": self}) return self.__class__(self._obj, **self._attrs) def renew(self, **attrs): - """ Create a new class with updated attributes """ + """Create a new class with updated attributes""" _log.debug(f"renew {self.__class__.__name__}", extra={"obj": self}) return self.__class__(self._obj, **{**self._attrs, **attrs}) @@ -92,7 +99,7 @@ def _get_class(self, allow_instance=False): @abstractmethod def dispatch(self, method): - """ Create dispatched method with correctly wrapped documentation + """Create dispatched method with correctly wrapped documentation This should return a function that mimics method but wraps it in some way. @@ -114,7 +121,7 @@ def __getattr__(self, key): class AbstractDispatcher(metaclass=ABCMeta): - """ A container for dispatchers + """A container for dispatchers This is an abstract class holding the dispatch classes (`AbstractDispatch`) and the attributes that are associated with the dispatchers. @@ -137,14 +144,18 @@ def __init__(self, dispatchs=None, default=None, **attrs): _log.info(f"__init__ {self.__class__.__name__}", extra={"obj": self}) def copy(self): - """ Create a copy of this object (making a new child for the dispatch lookup) """ + """Create a copy of this object (making a new child for the dispatch lookup)""" _log.debug(f"copy {self.__class__.__name__}", extra={"obj": self}) return self.__class__(self._dispatchs.new_child(), self._default, **self._attrs) def renew(self, **attrs): - """ Create a new class with updated attributes """ - _log.debug(f"renew {self.__class__.__name__}{tuple(attrs.keys())}", extra={"obj": self}) - return self.__class__(self._dispatchs, self._default, **{**self._attrs, **attrs}) + """Create a new class with updated attributes""" + _log.debug( + f"renew {self.__class__.__name__}{tuple(attrs.keys())}", extra={"obj": self} + ) + return self.__class__( + self._dispatchs, self._default, **{**self._attrs, **attrs} + ) def __len__(self): return len(self._dispatchs) @@ -161,6 +172,7 @@ def toline(kv): if k == self._default: return f"*{k} = {v}" return f" {k} = {v}" + dispatchs = _dict_to_str("dispatchs", self._dispatchs, parser=toline) attrs = _dict_to_str("attrs", self._attrs) if len(attrs) == 0: @@ -170,7 +182,7 @@ def toline(kv): return f"{self.__name__}{{{dispatchs},\n {attrs}\n}}" def __setitem__(self, key, dispatch): - """ Registers a dispatch method (using `register` with default values) + """Registers a dispatch method (using `register` with default values) Parameters ---------- @@ -182,11 +194,11 @@ def __setitem__(self, key, dispatch): self.register(key, dispatch) def __dir__(self): - """ Return instances belonging to this object """ + """Return instances belonging to this object""" return list(self._dispatchs.keys()) + ["renew", "register"] def register(self, key, dispatch, default=False, overwrite=True): - """ Register a dispatch class to this container + """Register a dispatch class to this container Parameter --------- @@ -200,28 +212,33 @@ def register(self, key, dispatch, default=False, overwrite=True): if true and `key` already exists in the list of dispatchs, then it will be overwritten, otherwise a `LookupError` is raised. """ - _log.info(f"register {self.__class__.__name__}(key: {key})", extra={"obj": self}) + _log.info( + f"register {self.__class__.__name__}(key: {key})", extra={"obj": self} + ) if key in self._dispatchs.maps[0] and not overwrite: - raise LookupError(f"{self.__class__.__name__} already has {key} registered (and overwrite is false)") + raise LookupError( + f"{self.__class__.__name__} already has {key} registered (and overwrite is false)" + ) self._dispatchs[key] = dispatch if default: self._default = key class ErrorDispatcher(AbstractDispatcher): - """ Faulty handler to ensure that certain operations are not allowed + """Faulty handler to ensure that certain operations are not allowed This may for instance be used with ``ClassDispatcher(instance_dispatcher=ErrorDispatcher)`` to ensure that a certain dispatch attribute will never be called on an instance. It won't work on type_dispatcher due to not being able to call `register`. """ - def __init__(self, obj, *args, **kwargs): # pylint: disable=W0231 - raise ValueError(f"Dispatcher on {obj} must not be called in this way, see documentation.") + def __init__(self, obj, *args, **kwargs): # pylint: disable=W0231 + raise ValueError( + f"Dispatcher on {obj} must not be called in this way, see documentation." + ) class MethodDispatcher(AbstractDispatcher): - def __init__(self, method, dispatchs=None, default=None, obj=None, **attrs): super().__init__(dispatchs, default, **attrs) update_wrapper(self, method) @@ -236,32 +253,45 @@ def __init__(self, method, dispatchs=None, default=None, obj=None, **attrs): _log.info(f"__init__ {self.__class__.__name__}", extra={"obj": self}) def copy(self): - """ Create a copy of this object (making a new child for the dispatch lookup) """ + """Create a copy of this object (making a new child for the dispatch lookup)""" _log.debug(f"copy {self.__class__.__name__}", extra={"obj": self}) - return self.__class__(self.__wrapped__, self._dispatchs.new_child(), self._default, - self._obj, **self._attrs) + return self.__class__( + self.__wrapped__, + self._dispatchs.new_child(), + self._default, + self._obj, + **self._attrs, + ) def renew(self, **attrs): - """ Create a new class with updated attributes """ + """Create a new class with updated attributes""" _log.debug(f"renew {self.__class__.__name__}", extra={"obj": self}) - return self.__class__(self.__wrapped__, self._dispatchs, self._default, - self._obj, **{**self._attrs, **attrs}) + return self.__class__( + self.__wrapped__, + self._dispatchs, + self._default, + self._obj, + **{**self._attrs, **attrs}, + ) def __call__(self, *args, **kwargs): _log.debug(f"call {self.__class__.__name__}{args}", extra={"obj": self}) if self._default is None: return self.__wrapped__(*args, **kwargs) - return self._dispatchs[self._default](self._obj, **self._attrs).dispatch(self.__wrapped__)(*args, **kwargs) + return self._dispatchs[self._default](self._obj, **self._attrs).dispatch( + self.__wrapped__ + )(*args, **kwargs) def __getitem__(self, key): - r""" Get method using dispatch according to `key` """ - _log.debug(f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self}) + r"""Get method using dispatch according to `key`""" + _log.debug( + f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self} + ) return self._dispatchs[key](self._obj, **self._attrs).dispatch(self.__wrapped__) __getattr__ = __getitem__ - def _parse_obj_getattr(func): """Parse `func` for all methods""" if func is None: @@ -271,16 +301,22 @@ def _parse_obj_getattr(func): # One can make getattr fail regardless of what to fetch from # the object if func == "error": + def func(obj, key): - raise AttributeError(f"{obj} does not implement the '{key}' dispatcher, " - "are you using it incorrectly?") + raise AttributeError( + f"{obj} does not implement the '{key}' dispatcher, " + "are you using it incorrectly?" + ) + return func - raise NotImplementedError(f"Defaulting the obj_getattr argument only accepts [error], got {func}.") + raise NotImplementedError( + f"Defaulting the obj_getattr argument only accepts [error], got {func}." + ) return func class ObjectDispatcher(AbstractDispatcher): - """ A dispatcher relying on object lookups + """A dispatcher relying on object lookups This dispatcher wraps a method call with lookup tables and possible defaults. @@ -299,7 +335,15 @@ class ObjectDispatcher(AbstractDispatcher): hello world """ - def __init__(self, obj, dispatchs=None, default=None, cls_attr_name=None, obj_getattr=None, **attrs): + def __init__( + self, + obj, + dispatchs=None, + default=None, + cls_attr_name=None, + obj_getattr=None, + **attrs, + ): super().__init__(dispatchs, default, **attrs) self._obj = obj self._obj_getattr = _parse_obj_getattr(obj_getattr) @@ -307,27 +351,33 @@ def __init__(self, obj, dispatchs=None, default=None, cls_attr_name=None, obj_ge _log.info(f"__init__ {self.__class__.__name__}", extra={"obj": self}) def copy(self): - """ Create a copy of this object (making a new child for the dispatch lookup) """ + """Create a copy of this object (making a new child for the dispatch lookup)""" _log.debug(f"copy {self.__class__.__name__}", extra={"obj": self}) - return self.__class__(self._obj, - dispatchs=self._dispatchs.new_child(), - default=self._default, - cls_attr_name=self._cls_attr_name, - obj_getattr=self._obj_getattr, - **self._attrs) + return self.__class__( + self._obj, + dispatchs=self._dispatchs.new_child(), + default=self._default, + cls_attr_name=self._cls_attr_name, + obj_getattr=self._obj_getattr, + **self._attrs, + ) def renew(self, **attrs): - """ Create a new class with updated attributes """ + """Create a new class with updated attributes""" _log.debug(f"renew {self.__class__.__name__}", extra={"obj": self}) - return self.__class__(self._obj, - dispatchs=self._dispatchs, - default=self._default, - cls_attr_name=self._cls_attr_name, - obj_getattr=self._obj_getattr, - **{**self._attrs, **attrs}) + return self.__class__( + self._obj, + dispatchs=self._dispatchs, + default=self._default, + cls_attr_name=self._cls_attr_name, + obj_getattr=self._obj_getattr, + **{**self._attrs, **attrs}, + ) def __call__(self, **attrs): - _log.debug(f"call {self.__class__.__name__}{tuple(attrs.keys())}", extra={"obj": self}) + _log.debug( + f"call {self.__class__.__name__}{tuple(attrs.keys())}", extra={"obj": self} + ) return self.renew(**attrs) def __repr__(self): @@ -340,7 +390,7 @@ def __str__(self): return super().__str__().replace("{", f"{{\n {obj},\n ", 1) def register(self, key, dispatch, default=False, overwrite=True, to_class=True): - """ Register a dispatch class to this object and to the object class instance (if existing) + """Register a dispatch class to this object and to the object class instance (if existing) Parameter --------- @@ -359,7 +409,9 @@ def register(self, key, dispatch, default=False, overwrite=True, to_class=True): whether the dispatch class will also be registered with the contained object's class instance """ - _log.info(f"register {self.__class__.__name__}(key: {key})", extra={"obj": self}) + _log.info( + f"register {self.__class__.__name__}(key: {key})", extra={"obj": self} + ) super().register(key, dispatch, default, overwrite) if to_class: cls_dispatch = getattr(self._obj.__class__, self._cls_attr_name, None) @@ -373,28 +425,39 @@ def __exit__(self, exc_type, exc_value, traceback): pass def __getitem__(self, key): - r""" Retrieve dispatched dispatchs by hash (allows functions to be dispatched) """ - _log.info(f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self}) + r"""Retrieve dispatched dispatchs by hash (allows functions to be dispatched)""" + _log.info( + f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self} + ) return self._dispatchs[key](self._obj, **self._attrs) def __getattr__(self, key): - """ Retrieve dispatched method by name, or if the name does not exist return a MethodDispatcher """ + """Retrieve dispatched method by name, or if the name does not exist return a MethodDispatcher""" if key in self._dispatchs: - _log.info(f"__getattr__ {self.__class__.__name__},dispatch={key}", extra={"obj": self}) + _log.info( + f"__getattr__ {self.__class__.__name__},dispatch={key}", + extra={"obj": self}, + ) return self._dispatchs[key](self._obj, **self._attrs) attr = self._obj_getattr(self._obj, key) if callable(attr): - _log.info(f"__getattr__ {self.__class__.__name__},method-dispatch={key}", extra={"obj": self}) + _log.info( + f"__getattr__ {self.__class__.__name__},method-dispatch={key}", + extra={"obj": self}, + ) # This will also ensure that if the user calls immediately after it will use the default - return MethodDispatcher(attr, self._dispatchs, self._default, - self._obj, **self._attrs) - _log.info(f"__getattr__ {self.__class__.__name__},method={key}", extra={"obj": self}) + return MethodDispatcher( + attr, self._dispatchs, self._default, self._obj, **self._attrs + ) + _log.info( + f"__getattr__ {self.__class__.__name__},method={key}", extra={"obj": self} + ) return attr class TypeDispatcher(ObjectDispatcher): - """ A dispatcher relying on type lookups + """A dispatcher relying on type lookups This dispatcher may be called directly and will query the dispatch method through the type of the first argument. @@ -411,7 +474,7 @@ class TypeDispatcher(ObjectDispatcher): """ def register(self, key, dispatch, default=False, overwrite=True, to_class=True): - """ Register a dispatch class to this object and to the object class instance (if existing) + """Register a dispatch class to this object and to the object class instance (if existing) Parameter --------- @@ -430,7 +493,9 @@ def register(self, key, dispatch, default=False, overwrite=True, to_class=True): whether the dispatch class will also be registered with the contained object's class instance """ - _log.info(f"register {self.__class__.__name__}(key: {key})", extra={"obj": self}) + _log.info( + f"register {self.__class__.__name__}(key: {key})", extra={"obj": self} + ) super().register(key, dispatch, default, overwrite, to_class=False) if to_class: cls_dispatch = getattr(self._obj, self._cls_attr_name, None) @@ -453,13 +518,15 @@ def __call__(self, obj, *args, **kwargs): return self._dispatchs[typ](self._obj)(obj, *args, **kwargs) def __getitem__(self, key): - r""" Retrieve dispatched dispatchs by hash (allows functions to be dispatched) """ - _log.info(f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self}) + r"""Retrieve dispatched dispatchs by hash (allows functions to be dispatched)""" + _log.info( + f"__getitem__ {self.__class__.__name__},key={key}", extra={"obj": self} + ) return self._dispatchs[key](self._obj, **self._attrs) class ClassDispatcher(AbstractDispatcher): - """ A dispatcher for classes, using `__get__` it converts into `ObjectDispatcher` upon invocation from an object, or a `TypeDispatcher` when invoked from a class + """A dispatcher for classes, using `__get__` it converts into `ObjectDispatcher` upon invocation from an object, or a `TypeDispatcher` when invoked from a class This is a class-placeholder allowing a dispatcher to be a class attribute and converted into an `ObjectDispatcher` when invoked from an object. @@ -493,11 +560,16 @@ class ClassDispatcher(AbstractDispatcher): The above defers any attributes to the contained `A.sub` attribute. """ - def __init__(self, attr_name, dispatchs=None, default=None, - obj_getattr=None, - instance_dispatcher=ObjectDispatcher, - type_dispatcher=TypeDispatcher, - **attrs): + def __init__( + self, + attr_name, + dispatchs=None, + default=None, + obj_getattr=None, + instance_dispatcher=ObjectDispatcher, + type_dispatcher=TypeDispatcher, + **attrs, + ): # obj_getattr is necessary for the ObjectDispatcher to create the correct # MethodDispatcher super().__init__(dispatchs, default, **attrs) @@ -510,23 +582,33 @@ def __init__(self, attr_name, dispatchs=None, default=None, _log.info(f"__init__ {self.__class__.__name__}", extra={"obj": self}) def copy(self): - """ Create a copy of this object (making a new child for the dispatch lookup) """ + """Create a copy of this object (making a new child for the dispatch lookup)""" _log.debug(f"copy {self.__class__.__name__}", extra={"obj": self}) - return self.__class__(self._attr_name, self._dispatchs.new_child(), self._default, - self._obj_getattr, - self._get.instance, self._get.type, - **self._attrs) + return self.__class__( + self._attr_name, + self._dispatchs.new_child(), + self._default, + self._obj_getattr, + self._get.instance, + self._get.type, + **self._attrs, + ) def renew(self, **attrs): - """ Create a new class with updated attributes """ + """Create a new class with updated attributes""" _log.debud(f"renew {self.__class__.__name__}", extra={"obj": self}) - return self.__class__(self._attr_name, self._dispatchs, self._default, - self._obj_getattr, - self._get.instance, self._get.type, - **{**self._attrs, **attrs}) + return self.__class__( + self._attr_name, + self._dispatchs, + self._default, + self._obj_getattr, + self._get.instance, + self._get.type, + **{**self._attrs, **attrs}, + ) def __get__(self, instance, owner): - """ Class dispatcher retrieval + """Class dispatcher retrieval The returned class depends on the setup of the `ClassDispatcher`. @@ -546,13 +628,20 @@ class object. inst = owner else: inst = instance - _log.debug(f"__get__ {self.__class__.__name__},instance={instance!r},inst={inst!r},owner={owner!r},cls={cls!r}", extra={"obj": self}) + _log.debug( + f"__get__ {self.__class__.__name__},instance={instance!r},inst={inst!r},owner={owner!r},cls={cls!r}", + extra={"obj": self}, + ) if cls is None: return self - return cls(inst, self._dispatchs, default=self._default, - cls_attr_name=self._attr_name, - obj_getattr=self._obj_getattr, - **self._attrs) + return cls( + inst, + self._dispatchs, + default=self._default, + cls_attr_name=self._attr_name, + obj_getattr=self._obj_getattr, + **self._attrs, + ) ''' diff --git a/src/sisl/_environ.py b/src/sisl/_environ.py index 92d66ed461..bd7b802f70 100644 --- a/src/sisl/_environ.py +++ b/src/sisl/_environ.py @@ -17,7 +17,7 @@ @contextmanager def sisl_environ(**environ): - r""" Create a new context for temporary overwriting the sisl environment variables + r"""Create a new context for temporary overwriting the sisl environment variables Parameters ---------- @@ -29,14 +29,17 @@ def sisl_environ(**environ): for key, value in environ.items(): old[key] = SISL_ENVIRON[key]["value"] SISL_ENVIRON[key]["value"] = value - yield # nothing to yield + yield # nothing to yield for key in environ: SISL_ENVIRON[key]["value"] = old[key] -def register_environ_variable(name: str , default: Any, - description: str=None, - process: Callable[[Any], Any]=None): +def register_environ_variable( + name: str, + default: Any, + description: str = None, + process: Callable[[Any], Any] = None, +): """Register a new global sisl environment variable. Parameters @@ -60,6 +63,7 @@ def register_environ_variable(name: str , default: Any, if not name.startswith("SISL_"): raise ValueError("register_environ_variable: name should start with 'SISL_'") if process is None: + def process(arg): return arg @@ -77,7 +81,7 @@ def process(arg): def get_environ_variable(name: str): - """ Gets the value of a registered environment variable. + """Gets the value of a registered environment variable. Parameters ----------- @@ -94,37 +98,59 @@ def get_environ_variable(name: str): except Exception: _nprocs = 1 -register_environ_variable("SISL_NUM_PROCS", min(1, _nprocs), - "Maximum number of CPU's used for parallel computing", - process=int) - -register_environ_variable("SISL_TMP", ".sisl_tmp", - "Path where temporary files should be stored", - process=Path) - -register_environ_variable("SISL_CONFIGDIR", Path.home() / ".config" / "sisl", - "Directory where configuration files for sisl should be stored", - process=Path) - -register_environ_variable("SISL_FILES_TESTS", "_THIS_DIRECTORY_DOES_NOT_EXIST_", - dedent("""\ +register_environ_variable( + "SISL_NUM_PROCS", + min(1, _nprocs), + "Maximum number of CPU's used for parallel computing", + process=int, +) + +register_environ_variable( + "SISL_TMP", ".sisl_tmp", "Path where temporary files should be stored", process=Path +) + +register_environ_variable( + "SISL_CONFIGDIR", + Path.home() / ".config" / "sisl", + "Directory where configuration files for sisl should be stored", + process=Path, +) + +register_environ_variable( + "SISL_FILES_TESTS", + "_THIS_DIRECTORY_DOES_NOT_EXIST_", + dedent( + """\ Full path of the sisl/files folder. Generally this is only used for tests and for documentations. - """), - process=Path) - -register_environ_variable("SISL_VIZ_AUTOLOAD", "false", - dedent("""\ + """ + ), + process=Path, +) + +register_environ_variable( + "SISL_VIZ_AUTOLOAD", + "false", + dedent( + """\ Determines whether the visualization module is automatically loaded. It may be good to leave auto load off if you are doing performance critical calculations to avoid the overhead of loading the visualization module. - """), - process=lambda val: val and val.lower().strip() in ["1", "t", "true"]) - -register_environ_variable("SISL_SHOW_PROGRESS", "false", - "Whether routines which can enable progress bars should show them by default or not.", - process=lambda val: val and val.lower().strip() in ["1", "t", "true"]) - -register_environ_variable("SISL_IO_DEFAULT", "", - "The default DFT code for processing files, Siles will be compared with endswidth(<>).", - process=lambda val: val.lower()) + """ + ), + process=lambda val: val and val.lower().strip() in ["1", "t", "true"], +) + +register_environ_variable( + "SISL_SHOW_PROGRESS", + "false", + "Whether routines which can enable progress bars should show them by default or not.", + process=lambda val: val and val.lower().strip() in ["1", "t", "true"], +) + +register_environ_variable( + "SISL_IO_DEFAULT", + "", + "The default DFT code for processing files, Siles will be compared with endswidth(<>).", + process=lambda val: val.lower(), +) diff --git a/src/sisl/_help.py b/src/sisl/_help.py index fb6ae8ca6a..615ad40043 100644 --- a/src/sisl/_help.py +++ b/src/sisl/_help.py @@ -25,6 +25,7 @@ try: from defusedxml import __version__ as defusedxml_version from defusedxml.ElementTree import parse as xml_parse + try: defusedxml_version = list(map(int, defusedxml_version.split("."))) if defusedxml_version[0] == 0 and defusedxml_version[1] <= 5: @@ -49,8 +50,10 @@ def array_fill_repeat(array, size, cls=None): if size % len(array) != 0: # We do not have it correctly formatted (either an integer # repeatable part, full, or a single) - raise ValueError("Repetition of or array is not divisible with actual length. " - "Hence we cannot create a repeated size.") + raise ValueError( + "Repetition of or array is not divisible with actual length. " + "Hence we cannot create a repeated size." + ) if cls is None: if reps > 1: return np.tile(array, reps) @@ -63,7 +66,7 @@ def array_fill_repeat(array, size, cls=None): @set_module("sisl") def voigt_matrix(M, to_voigt): - r""" Convert a matrix from Voigt representation to dense, or from matrix to Voigt + r"""Convert a matrix from Voigt representation to dense, or from matrix to Voigt Parameters ---------- @@ -76,30 +79,31 @@ def voigt_matrix(M, to_voigt): """ if to_voigt: m = np.empty(M.shape[:-2] + (6,), dtype=M.dtype) - m[..., 0] = M[..., 0, 0] # xx - m[..., 1] = M[..., 1, 1] # yy - m[..., 2] = M[..., 2, 2] # zz - m[..., 3] = (M[..., 2, 1] + M[..., 1, 2]) * 0.5 # zy - m[..., 4] = (M[..., 2, 0] + M[..., 0, 2]) * 0.5 # zx - m[..., 5] = (M[..., 1, 0] + M[..., 0, 1]) * 0.5 # xy + m[..., 0] = M[..., 0, 0] # xx + m[..., 1] = M[..., 1, 1] # yy + m[..., 2] = M[..., 2, 2] # zz + m[..., 3] = (M[..., 2, 1] + M[..., 1, 2]) * 0.5 # zy + m[..., 4] = (M[..., 2, 0] + M[..., 0, 2]) * 0.5 # zx + m[..., 5] = (M[..., 1, 0] + M[..., 0, 1]) * 0.5 # xy else: m = np.empty(M.shape[:-1] + (3, 3), dtype=M.dtype) - m[..., 0, 0] = M[..., 0] # xx - m[..., 1, 1] = M[..., 1] # yy - m[..., 2, 2] = M[..., 2] # zz - m[..., 0, 1] = M[..., 5] # xy - m[..., 1, 0] = M[..., 5] # xy - m[..., 0, 2] = M[..., 4] # xz - m[..., 2, 0] = M[..., 4] # xz - m[..., 1, 2] = M[..., 3] # zy - m[..., 2, 1] = M[..., 3] # zy + m[..., 0, 0] = M[..., 0] # xx + m[..., 1, 1] = M[..., 1] # yy + m[..., 2, 2] = M[..., 2] # zz + m[..., 0, 1] = M[..., 5] # xy + m[..., 1, 0] = M[..., 5] # xy + m[..., 0, 2] = M[..., 4] # xz + m[..., 2, 0] = M[..., 4] # xz + m[..., 1, 2] = M[..., 3] # zy + m[..., 2, 1] = M[..., 3] # zy return m + _Iterable = collections_abc.Iterable def isiterable(obj): - """ Returns whether the object is an iterable or not """ + """Returns whether the object is an iterable or not""" return isinstance(obj, _Iterable) @@ -107,12 +111,12 @@ def isiterable(obj): def isndarray(arr): - """ Returns ``True`` if the input object is a `numpy.ndarray` object """ + """Returns ``True`` if the input object is a `numpy.ndarray` object""" return isinstance(arr, _ndarray) -def get_dtype(var, int=None, other=None): # pylint: disable=W0622 - """ Returns the `numpy.dtype` equivalent of `var`. +def get_dtype(var, int=None, other=None): # pylint: disable=W0622 + """Returns the `numpy.dtype` equivalent of `var`. Parameters ---------- @@ -173,7 +177,7 @@ def get_dtype(var, int=None, other=None): # pylint: disable=W0622 def dtype_complex_to_real(dtype): - """ Return the equivalent precision real data-type if the `dtype` is complex """ + """Return the equivalent precision real data-type if the `dtype` is complex""" if dtype == np.complex128: return np.float64 elif dtype == np.complex64: @@ -182,7 +186,7 @@ def dtype_complex_to_real(dtype): def dtype_real_to_complex(dtype): - """ Return the equivalent precision complex data-type if the `dtype` is real """ + """Return the equivalent precision complex data-type if the `dtype` is real""" if dtype == np.float64: return np.complex128 elif dtype == np.float32: @@ -191,7 +195,7 @@ def dtype_real_to_complex(dtype): def array_replace(array, *replace, **kwargs): - """ Replace values in `array` using `replace` + """Replace values in `array` using `replace` Replaces values in `array` using tuple/val in `replace`. @@ -231,7 +235,7 @@ def array_replace(array, *replace, **kwargs): def wrap_filterwarnings(*args, **kwargs): - """ Instead of creating nested `with` statements one can wrap entire functions with a filter + """Instead of creating nested `with` statements one can wrap entire functions with a filter The following two are equivalent: @@ -250,11 +254,14 @@ def wrap_filterwarnings(*args, **kwargs): **kwargs : keyword arguments passed to `warnings.filterwarnings` """ + def decorator(func): @functools.wraps(func) def wrap_func(*func_args, **func_kwargs): with warnings.catch_warnings(): warnings.filterwarnings(*args, **kwargs) return func(*func_args, **func_kwargs) + return wrap_func + return decorator diff --git a/src/sisl/_internal.py b/src/sisl/_internal.py index 893a0b04a0..929bd5c781 100644 --- a/src/sisl/_internal.py +++ b/src/sisl/_internal.py @@ -11,8 +11,10 @@ def set_module(module): r"""Decorator for overriding __module__ on a function or class""" + def deco(f_or_c): if module is not None: f_or_c.__module__ = module return f_or_c + return deco diff --git a/src/sisl/_namedindex.py b/src/sisl/_namedindex.py index b6daea0968..f5a2563a84 100644 --- a/src/sisl/_namedindex.py +++ b/src/sisl/_namedindex.py @@ -19,7 +19,7 @@ @set_module("sisl") class NamedIndex: - __slots__ = ('_name', '_index') + __slots__ = ("_name", "_index") def __init__(self, name=None, index=None): if isinstance(name, dict): @@ -39,28 +39,28 @@ def __init__(self, name=None, index=None): @property def names(self): - """ All names contained """ + """All names contained""" return self._name def clear(self): - """ Clear all names in this object, no names will exist after this call (in-place) """ + """Clear all names in this object, no names will exist after this call (in-place)""" self._name = [] self._index = [] def __iter__(self): - """ Iterate names in the group """ + """Iterate names in the group""" yield from self._name def __len__(self): - """ Number of uniquely defined names """ + """Number of uniquely defined names""" return len(self._name) def copy(self): - """ Create a copy of this """ + """Create a copy of this""" return self.__class__(self._name[:], [i.copy() for i in self._index]) def add_name(self, name, index): - """ Add a new named group. The indices (`index`) will be associated with the name `name` + """Add a new named group. The indices (`index`) will be associated with the name `name` Parameters ---------- @@ -70,14 +70,16 @@ def add_name(self, name, index): the indices that has a name associated """ if name in self._name: - raise SislError(f"{self.__class__.__name__}.add_name already contains name {name}, please delete group name before adding.") + raise SislError( + f"{self.__class__.__name__}.add_name already contains name {name}, please delete group name before adding." + ) self._name.append(name) if isinstance(index, ndarray) and index.dtype == bool_: index = np.flatnonzero(index) self._index.append(arrayi(index).ravel()) def delete_name(self, name): - """ Delete an existing named group, if the group does not exist, nothing will happen. + """Delete an existing named group, if the group does not exist, nothing will happen. Parameters ---------- @@ -89,24 +91,24 @@ def delete_name(self, name): del self._index[i] def __str__(self): - """ Representation of object """ + """Representation of object""" N = len(self) if N == 0: - return self.__class__.__name__ + '{}' - s = self.__class__.__name__ + f'{{groups: {N}' + return self.__class__.__name__ + "{}" + s = self.__class__.__name__ + f"{{groups: {N}" for name, idx in zip(self._name, self._index): - s += ',\n {}: [{}]'.format(name, list2str(idx)) - return s + '\n}' + s += ",\n {}: [{}]".format(name, list2str(idx)) + return s + "\n}" def __setitem__(self, name, index): - """ Equivalent to `add_name` """ + """Equivalent to `add_name`""" if isinstance(name, str): self.add_name(name, index) else: self.add_name(index, name) def index(self, name): - """ Return indices of the group via `name` """ + """Return indices of the group via `name`""" try: i = self._name.index(name) return self._index[i] @@ -116,19 +118,19 @@ def index(self, name): return name def __getitem__(self, name): - """ Return indices of the group """ + """Return indices of the group""" return self.index(name) def __delitem__(self, name): - """ Delete a named group """ + """Delete a named group""" self.delete_name(name) def __contains__(self, name): - """ Check whether a name exists in this group a named group """ + """Check whether a name exists in this group a named group""" return name in self._name def merge(self, other, offset=0, duplicate="raise"): - """ Return a new object which is a merge of self and other + """Return a new object which is a merge of self and other By default, name conflicts between self and other will raise a ValueError. See the `duplicate` parameter for information on how to change this. @@ -165,7 +167,11 @@ def merge(self, other, offset=0, duplicate="raise"): pass elif duplicate == "raise": - raise ValueError("{}.merge has overlapping names without a duplicate handler: {}".format(self.__class__.__name__, list(intersection))) + raise ValueError( + "{}.merge has overlapping names without a duplicate handler: {}".format( + self.__class__.__name__, list(intersection) + ) + ) elif duplicate == "union": # Indices are made a union @@ -191,12 +197,14 @@ def merge(self, other, offset=0, duplicate="raise"): del new[name] else: - raise ValueError(f"{self.__class__.__name__}.merge wrong argument: duplicate.") + raise ValueError( + f"{self.__class__.__name__}.merge wrong argument: duplicate." + ) return new def sub_index(self, index, name=None): - """ Get a new object with only the indexes in idx. + """Get a new object with only the indexes in idx. Parameters ---------- @@ -221,7 +229,7 @@ def sub_index(self, index, name=None): return self.__class__(self._name[:], new_index) def sub_name(self, names): - """ Get a new object with only the names in `names` + """Get a new object with only the names in `names` Parameters ---------- @@ -233,7 +241,9 @@ def sub_name(self, names): for name in names: if name not in self._name: - raise ValueError(f"{self.__class__.__name__}.sub_name specified name ({name}) is not in object.") + raise ValueError( + f"{self.__class__.__name__}.sub_name specified name ({name}) is not in object." + ) new_index = [] for name in self: @@ -245,7 +255,7 @@ def sub_name(self, names): return self.__class__(names, new_index) def remove_index(self, index): - """ Remove indices from all named index groups + """Remove indices from all named index groups Parameters ---------- @@ -261,7 +271,7 @@ def remove_index(self, index): return new def reduce(self): - """ Removes names from the object which have no index associated (in-place) """ + """Removes names from the object which have no index associated (in-place)""" for i in range(len(self))[::-1]: if len(self._index[i]) == 0: del self[self.names[i]] diff --git a/src/sisl/_plot.py b/src/sisl/_plot.py index 56d682b841..5122b23764 100644 --- a/src/sisl/_plot.py +++ b/src/sisl/_plot.py @@ -10,6 +10,7 @@ import matplotlib as mlib import matplotlib.pyplot as mlibplt import mpl_toolkits.mplot3d as mlib3d + has_matplotlib = True except Exception as _matplotlib_import_exception: mlib = NotImplementedError @@ -17,7 +18,7 @@ mlib3d = NotImplementedError has_matplotlib = False -__all__ = ['plot', 'mlib', 'mlibplt', 'mlib3d', 'get_axes'] +__all__ = ["plot", "mlib", "mlibplt", "mlib3d", "get_axes"] def get_axes(axes=False, **kwargs): @@ -33,16 +34,21 @@ def get_axes(axes=False, **kwargs): def _plot(obj, *args, **kwargs): try: - a = getattr(obj, '__plot__') + a = getattr(obj, "__plot__") except AttributeError: - raise NotImplementedError(f"{obj.__class__.__name__} does not implement the __plot__ method.") + raise NotImplementedError( + f"{obj.__class__.__name__} does not implement the __plot__ method." + ) return a(*args, **kwargs) + if has_matplotlib: plot = _plot else: + def plot(obj, *args, **kwargs): - raise _matplotlib_import_exception # pylint: disable=E0601 + raise _matplotlib_import_exception # pylint: disable=E0601 + # Clean up del has_matplotlib diff --git a/src/sisl/_typing.py b/src/sisl/_typing.py index ac4a05d2f6..c64b54780e 100644 --- a/src/sisl/_typing.py +++ b/src/sisl/_typing.py @@ -10,7 +10,7 @@ from sisl import Atom, Atoms, Geometry, Lattice -#from typing import TYPE_CHECKING, final +# from typing import TYPE_CHECKING, final AtomsLike = Union[ Atom, diff --git a/src/sisl/_typing_ext/numpy.py b/src/sisl/_typing_ext/numpy.py index ced180bd7a..1e67cdff36 100644 --- a/src/sisl/_typing_ext/numpy.py +++ b/src/sisl/_typing_ext/numpy.py @@ -3,7 +3,7 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. from numpy import __version__ -if tuple(map(int, __version__.split('.'))) >= (1, 21, 0): +if tuple(map(int, __version__.split("."))) >= (1, 21, 0): # NDArray entered in 1.21. # numpy.typing entered in 1.20.0 # we have numpy typing diff --git a/src/sisl/atom.py b/src/sisl/atom.py index d00423a85f..353cc64864 100644 --- a/src/sisl/atom.py +++ b/src/sisl/atom.py @@ -21,12 +21,12 @@ from .orbital import Orbital from .shape import Sphere -__all__ = ['PeriodicTable', 'Atom', 'AtomUnknown', 'AtomGhost', 'Atoms'] +__all__ = ["PeriodicTable", "Atom", "AtomUnknown", "AtomGhost", "Atoms"] @set_module("sisl") class PeriodicTable: - r""" Periodic table for creating an `Atom`, or retrieval of atomic information via atomic numbers + r"""Periodic table for creating an `Atom`, or retrieval of atomic information via atomic numbers Enables *lookup* of atomic numbers/names/labels to get the atomic number. @@ -800,7 +800,7 @@ class PeriodicTable: # fmt: on def Z(self, key): - """ Atomic number based on general input + """Atomic number based on general input Return the atomic number corresponding to the `key` lookup. @@ -834,7 +834,7 @@ def Z(self, key): Z_int = Z def Z_label(self, key): - """ Atomic label of the corresponding atom + """Atomic label of the corresponding atom Return the atomic short name corresponding to the `key` lookup. @@ -859,7 +859,7 @@ def Z_label(self, key): Z_short = Z_label def atomic_mass(self, key): - """ Atomic mass of the corresponding atom + """Atomic mass of the corresponding atom Return the atomic mass corresponding to the `key` lookup. @@ -878,11 +878,11 @@ def atomic_mass(self, key): Z = self.Z_int(key) get = self._atomic_mass.get if isinstance(Z, (Integral, Real)): - return get(Z, 0.) + return get(Z, 0.0) return _a.fromiterd(map(get, Z, iter(float, 1))) - def radius(self, key, method='calc'): - """ Atomic radii using different methods + def radius(self, key, method="calc"): + """Atomic radii using different methods Return the atomic radii. @@ -908,17 +908,20 @@ def radius(self, key, method='calc'): if isinstance(Z, Integral): return func(Z) / 100 return _a.fromiterd(map(func, Z)) / 100 + radii = radius + # Create a local instance of the periodic table to # faster look up _ptbl = PeriodicTable() class AtomMeta(type): - """ Meta class for key-lookup on the class. """ + """Meta class for key-lookup on the class.""" + def __getitem__(cls, key): - """ Create a new atom object """ + """Create a new atom object""" if isinstance(key, Atom): # if the key already is an atomic object # return it @@ -930,7 +933,7 @@ def __getitem__(cls, key): elif isinstance(key, (list, tuple)): # The key is a list, # we need to create a list of atoms - return [cls[k] for k in key] # pylint: disable=E1136 + return [cls[k] for k in key] # pylint: disable=E1136 # Index Z based return cls(key) @@ -940,14 +943,13 @@ def __getitem__(cls, key): # class ...(..., metaclass=MetaClass) # This below construct handles both python2 and python3 cases @set_module("sisl") -class Atom(_Dispatchs, - dispatchs=[("to", - ClassDispatcher("to", - type_dispatcher=None))], - when_subclassing="keep", - metaclass=AtomMeta, - ): - """ Atomic information, mass, name number of orbitals and ranges +class Atom( + _Dispatchs, + dispatchs=[("to", ClassDispatcher("to", type_dispatcher=None))], + when_subclassing="keep", + metaclass=AtomMeta, +): + """Atomic information, mass, name number of orbitals and ranges Object to handle atomic mass, name, number of orbitals and orbital range. @@ -974,8 +976,9 @@ class Atom(_Dispatchs, arbitrary designation for user handling similar atoms with different settings (defaults to the label of the atom) """ + def __new__(cls, *args, **kwargs): - """ Figure out which class to actually use """ + """Figure out which class to actually use""" # Handle the case where no arguments are passed (e.g. for serializing stuff) if len(args) == 0 and "Z" not in kwargs: return super().__new__(cls) @@ -1025,13 +1028,15 @@ def __init__(self, Z, orbitals=None, mass=None, tag=None, **kwargs): if self._orbitals is None: if orbitals is not None: - raise ValueError(f"{self.__class__.__name__}.__init__ got unparseable 'orbitals' argument: {orbitals}") - if 'R' in kwargs: + raise ValueError( + f"{self.__class__.__name__}.__init__ got unparseable 'orbitals' argument: {orbitals}" + ) + if "R" in kwargs: # backwards compatibility (possibly remove this in the future) - R = _a.asarrayd(kwargs['R']).ravel() + R = _a.asarrayd(kwargs["R"]).ravel() self._orbitals = [Orbital(r) for r in R] else: - self._orbitals = [Orbital(-1.)] + self._orbitals = [Orbital(-1.0)] if mass is None: self._mass = _ptbl.atomic_mass(self.Z) @@ -1048,40 +1053,40 @@ def __hash__(self): @property def Z(self) -> NDArray[np.int32]: - """ Atomic number """ + """Atomic number""" return self._Z @property def orbitals(self): - """ List of orbitals """ + """List of orbitals""" return self._orbitals @property def mass(self) -> NDArray[np.float64]: - """ Atomic mass """ + """Atomic mass""" return self._mass @property def tag(self) -> str: - """ Tag for atom """ + """Tag for atom""" return self._tag @property def no(self) -> int: - """ Number of orbitals on this atom """ + """Number of orbitals on this atom""" return len(self.orbitals) def index(self, orbital): - """ Return the index of the orbital in the atom object """ + """Return the index of the orbital in the atom object""" if not isinstance(orbital, Orbital): orbital = self[orbital] for i, o in enumerate(self.orbitals): if o == orbital: return i - raise KeyError('Could not find `orbital` in the list of orbitals.') + raise KeyError("Could not find `orbital` in the list of orbitals.") def sub(self, orbitals): - """ Return the same atom with only a subset of the orbitals present + """Return the same atom with only a subset of the orbitals present Parameters ---------- @@ -1100,15 +1105,19 @@ def sub(self, orbitals): """ orbitals = _a.arrayi(orbitals).ravel() if len(orbitals) > self.no: - raise ValueError(f"{self.__class__.__name__}.sub tries to remove more than the number of orbitals on an atom.") + raise ValueError( + f"{self.__class__.__name__}.sub tries to remove more than the number of orbitals on an atom." + ) if np.any(orbitals >= self.no): - raise ValueError(f"{self.__class__.__name__}.sub tries to remove a non-existing orbital io > no.") + raise ValueError( + f"{self.__class__.__name__}.sub tries to remove a non-existing orbital io > no." + ) orbs = [self.orbitals[o].copy() for o in orbitals] return self.copy(orbitals=orbs) def remove(self, orbitals): - """ Return the same atom without a specific set of orbitals + """Return the same atom without a specific set of orbitals Parameters ---------- @@ -1128,16 +1137,18 @@ def remove(self, orbitals): return self.sub(orbs) def copy(self, Z=None, orbitals=None, mass=None, tag=None): - """ Return copy of this object """ + """Return copy of this object""" if orbitals is None: orbitals = [orb.copy() for orb in self] - return self.__class__(self.Z if Z is None else Z, - orbitals, - self.mass if mass is None else mass, - self.tag if tag is None else tag) + return self.__class__( + self.Z if Z is None else Z, + orbitals, + self.mass if mass is None else mass, + self.tag if tag is None else tag, + ) - def radius(self, method='calc'): - """ Return the atomic radii of the atom (in Ang) + def radius(self, method="calc"): + """Return the atomic radii of the atom (in Ang) See `PeriodicTable.radius` for details on the argument. """ @@ -1145,11 +1156,11 @@ def radius(self, method='calc'): @property def symbol(self): - """ Return short atomic name (Au==79). """ + """Return short atomic name (Au==79).""" return _ptbl.Z_short(self.Z) def __getitem__(self, key): - """ The orbital corresponding to index `key` """ + """The orbital corresponding to index `key`""" if isinstance(key, slice): ol = key.indices(len(self)) return [self.orbitals[o] for o in range(*ol)] @@ -1164,14 +1175,14 @@ def __getitem__(self, key): return [self.orbitals[o] for o in key] def maxR(self): - """ Return the maximum range of orbitals. """ + """Return the maximum range of orbitals.""" mR = -1e10 for o in self.orbitals: mR = max(mR, o.R) return mR def scale(self, scale): - """ Scale the atomic radii and return an equivalent atom. + """Scale the atomic radii and return an equivalent atom. Parameters ---------- @@ -1183,11 +1194,11 @@ def scale(self, scale): return new def __iter__(self): - """ Loop on all orbitals in this atom """ + """Loop on all orbitals in this atom""" yield from self.orbitals def iter(self, group=False): - """ Loop on all orbitals in this atom + """Loop on all orbitals in this atom Parameters ---------- @@ -1219,18 +1230,23 @@ def iter(self, group=False): def __str__(self): # Create orbitals output - orbs = ',\n '.join([str(o) for o in self.orbitals]) - return self.__class__.__name__ + '{{{0}, Z: {1:d}, mass(au): {2:.5f}, maxR: {3:.5f},\n {4}\n}}'.format(self.tag, self.Z, self.mass, self.maxR(), orbs) + orbs = ",\n ".join([str(o) for o in self.orbitals]) + return ( + self.__class__.__name__ + + "{{{0}, Z: {1:d}, mass(au): {2:.5f}, maxR: {3:.5f},\n {4}\n}}".format( + self.tag, self.Z, self.mass, self.maxR(), orbs + ) + ) def __repr__(self): return f"<{self.__module__}.{self.__class__.__name__} {self.tag}, Z={self.Z}, M={self.mass}, maxR={self.maxR()}, no={len(self.orbitals)}>" def __len__(self): - """ Return number of orbitals in this atom """ + """Return number of orbitals in this atom""" return self.no def __getattr__(self, attr): - """ Pass attribute calls to the orbital classes and return lists/array + """Pass attribute calls to the orbital classes and return lists/array Parameters ---------- @@ -1254,7 +1270,9 @@ def __getattr__(self, attr): if found == 0: # we never got any values, reraise the AttributeError - raise AttributeError(f"'{self.__class__.__name__}.orbitals' objects has no attribute '{attr}'") + raise AttributeError( + f"'{self.__class__.__name__}.orbitals' objects has no attribute '{attr}'" + ) # Now parse the data, currently we'll only allow Integral, Real, Complex if is_Integral: @@ -1265,11 +1283,13 @@ def __getattr__(self, attr): elif is_Real: for io in range(len(vals)): if vals[io] is None: - vals[io] = 0. + vals[io] = 0.0 return _a.arrayd(vals) elif is_callable: + def _ret_none(*args, **kwargs): return None + for io in range(len(vals)): if vals[io] is None: vals[io] = _ret_none @@ -1278,16 +1298,20 @@ def _ret_none(*args, **kwargs): class ArrayCall: def __init__(self, methods): self.methods = methods + def __call__(self, *args, **kwargs): return [m(*args, **kwargs) for m in self.methods] + return ArrayCall(vals) # We don't know how to handle this, simply return... return vals - @deprecation("toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15") + @deprecation( + "toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15" + ) def toSphere(self, center=None): - """ Return a sphere with the maximum orbital radius equal + """Return a sphere with the maximum orbital radius equal Returns ------- @@ -1297,7 +1321,7 @@ def toSphere(self, center=None): return self.to.Sphere(center=center) def equal(self, other, R=True, psi=False): - """ True if `other` is the same as this atomic specie + """True if `other` is the same as this atomic specie Parameters ---------- @@ -1313,14 +1337,19 @@ def equal(self, other, R=True, psi=False): same = self.Z == other.Z same &= self.no == other.no if same and R: - same &= all([self.orbitals[i].equal(other.orbitals[i], psi=psi) for i in range(self.no)]) + same &= all( + [ + self.orbitals[i].equal(other.orbitals[i], psi=psi) + for i in range(self.no) + ] + ) same &= np.isclose(self.mass, other.mass) same &= self.tag == other.tag return same # Check whether they are equal def __eq__(self, b): - """ Return true if the saved quantities are the same """ + """Return true if the saved quantities are the same""" return self.equal(b) def __ne__(self, b): @@ -1328,18 +1357,23 @@ def __ne__(self, b): # Create pickling routines def __getstate__(self): - """ Return the state of this object """ - return {'Z': self.Z, 'orbitals': self.orbitals, 'mass': self.mass, 'tag': self.tag} + """Return the state of this object""" + return { + "Z": self.Z, + "orbitals": self.orbitals, + "mass": self.mass, + "tag": self.tag, + } def __setstate__(self, d): - """ Re-create the state of this object """ - self.__init__(d['Z'], d['orbitals'], d['mass'], d['tag']) + """Re-create the state of this object""" + self.__init__(d["Z"], d["orbitals"], d["mass"], d["tag"]) @set_module("sisl") class AtomUnknown(Atom): def __init__(self, Z, *args, **kwargs): - """ Instantiate with overridden tag """ + """Instantiate with overridden tag""" if len(args) < 3 and "tag" not in kwargs: kwargs["tag"] = "unknown" super().__init__(Z, *args, **kwargs) @@ -1348,9 +1382,9 @@ def __init__(self, Z, *args, **kwargs): @set_module("sisl") class AtomGhost(AtomUnknown): def __init__(self, Z, *args, **kwargs): - """ Instantiate with overridden tag and taking the absolute value of Z """ + """Instantiate with overridden tag and taking the absolute value of Z""" if Z < 0: - Z = - Z + Z = -Z if len(args) < 3 and "tag" not in kwargs: kwargs["tag"] = "ghost" if len(args) < 2 and "mass" not in kwargs: @@ -1360,18 +1394,23 @@ def __init__(self, Z, *args, **kwargs): class AtomToDispatch(AbstractDispatch): - """ Base dispatcher from class passing from an Atom class """ + """Base dispatcher from class passing from an Atom class""" + to_dispatch = Atom.to + + class ToSphereDispatch(AtomToDispatch): def dispatch(self, *args, center=None, **kwargs): return Sphere(self._get_object().maxR(), center) + + to_dispatch.register("Sphere", ToSphereDispatch) @set_module("sisl") class Atoms: - """ A list-like object to contain a list of different atoms with minimum + """A list-like object to contain a list of different atoms with minimum data duplication. This holds multiple `Atom` objects which are indexed via a species @@ -1400,7 +1439,7 @@ class repeated `na` times. Creating a set of 4 atoms, 2 Hydrogen, 2 Helium, in an alternate ordere - + >>> Z = [1, 2, 1, 2] >>> atoms = Atoms(Z) @@ -1414,7 +1453,6 @@ class repeated `na` times. __slots__ = ("_atom", "_specie", "_firsto") def __init__(self, atoms="H", na=None): - # Default value of the atom object if atoms is None: atoms = Atom("H") @@ -1470,13 +1508,13 @@ def __init__(self, atoms="H", na=None): self._update_orbitals() def _update_orbitals(self): - """ Internal routine for updating the `firsto` attribute """ + """Internal routine for updating the `firsto` attribute""" # Get number of orbitals per specie uorbs = _a.arrayi([a.no for a in self.atom]) self._firsto = np.insert(_a.cumsumi(uorbs[self.specie]), 0, 0) def copy(self): - """ Return a copy of this atom """ + """Return a copy of this atom""" atoms = Atoms() atoms._atom = [a.copy() for a in self._atom] atoms._specie = np.copy(self._specie) @@ -1485,48 +1523,48 @@ def copy(self): @property def atom(self): - """ List of unique atoms in this group of atoms """ + """List of unique atoms in this group of atoms""" return self._atom @property def nspecie(self): - """ Number of different species """ + """Number of different species""" return len(self._atom) @property def specie(self): - """ List of atomic species """ + """List of atomic species""" return self._specie @property def no(self): - """ Total number of orbitals in this list of atoms """ + """Total number of orbitals in this list of atoms""" uorbs = _a.arrayi([a.no for a in self.atom]) return uorbs[self.specie].sum() @property def orbitals(self): - """ Array of orbitals of the contained objects """ + """Array of orbitals of the contained objects""" return np.diff(self.firsto) @property def firsto(self): - """ First orbital of the corresponding atom in the consecutive list of orbitals """ + """First orbital of the corresponding atom in the consecutive list of orbitals""" return self._firsto @property def lasto(self): - """ Last orbital of the corresponding atom in the consecutive list of orbitals """ + """Last orbital of the corresponding atom in the consecutive list of orbitals""" return self._firsto[1:] - 1 @property def q0(self): - """ Initial charge per atom """ + """Initial charge per atom""" q0 = _a.arrayd([a.q0.sum() for a in self.atom]) return q0[self.specie] def orbital(self, io): - """ Return an array of orbital of the contained objects """ + """Return an array of orbital of the contained objects""" io = _a.asarrayi(io) ndim = io.ndim io = io.ravel() % self.no @@ -1539,7 +1577,7 @@ def orbital(self, io): return [self.atom[ia].orbitals[o] for ia, o in zip(a, io)] def maxR(self, all=False): - """ The maximum radius of the atoms + """The maximum radius of the atoms Parameters ---------- @@ -1555,18 +1593,18 @@ def maxR(self, all=False): @property def mass(self): - """ Array of masses of the contained objects """ + """Array of masses of the contained objects""" umass = _a.arrayd([a.mass for a in self.atom]) return umass[self.specie[:]] @property def Z(self): - """ Array of atomic numbers """ + """Array of atomic numbers""" uZ = _a.arrayi([a.Z for a in self.atom]) return uZ[self.specie[:]] def scale(self, scale): - """ Scale the atomic radii and return an equivalent atom. + """Scale the atomic radii and return an equivalent atom. Parameters ---------- @@ -1579,20 +1617,20 @@ def scale(self, scale): return atoms def index(self, atom): - """ Return the indices of the atom object """ + """Return the indices of the atom object""" return (self._specie == self.specie_index(atom)).nonzero()[0] def specie_index(self, atom): - """ Return the species index of the atom object """ + """Return the species index of the atom object""" if not isinstance(atom, Atom): atom = self[atom] for s, a in enumerate(self.atom): if a == atom: return s - raise KeyError('Could not find `atom` in the list of atoms.') + raise KeyError("Could not find `atom` in the list of atoms.") def group_atom_data(self, data, axis=0): - r""" Group data for each atom based on number of orbitals + r"""Group data for each atom based on number of orbitals This is useful for grouping data that is orbitally resolved. This will return a list of length ``len(self)`` and with each item @@ -1619,7 +1657,7 @@ def group_atom_data(self, data, axis=0): return np.split(data, self.lasto[:-1] + 1, axis=axis) def reorder(self, in_place=False): - """ Reorders the atoms and species index so that they are ascending (starting with a specie that exists) + """Reorders the atoms and species index so that they are ascending (starting with a specie that exists) Parameters ---------- @@ -1674,6 +1712,7 @@ def formula(self, system="Hill"): if systeml == "hill": # sort lexographically symbols = sorted(c.keys()) + def parse(symbol_c): symbol, c = symbol_c if c == 1: @@ -1682,10 +1721,12 @@ def parse(symbol_c): return "".join(map(parse, sorted(c.items()))) - raise ValueError(f"{self.__class__.__name__}.formula got unrecognized argument 'system' {system}") + raise ValueError( + f"{self.__class__.__name__}.formula got unrecognized argument 'system' {system}" + ) def reduce(self, in_place=False): - """ Returns a new `Atoms` object by removing non-used atoms """ + """Returns a new `Atoms` object by removing non-used atoms""" if in_place: atoms = self else: @@ -1710,7 +1751,7 @@ def reduce(self, in_place=False): return atoms def sub(self, atoms): - """ Return a subset of the list """ + """Return a subset of the list""" atoms = _a.asarrayi(atoms).ravel() new_atoms = Atoms() new_atoms._atom = self._atom[:] @@ -1719,27 +1760,27 @@ def sub(self, atoms): return new_atoms def remove(self, atoms): - """ Remove a set of atoms """ + """Remove a set of atoms""" atoms = _a.asarrayi(atoms).ravel() idx = np.setdiff1d(np.arange(len(self)), atoms, assume_unique=True) return self.sub(idx) def tile(self, reps): - """ Tile this atom object """ + """Tile this atom object""" atoms = self.copy() atoms._specie = np.tile(atoms._specie, reps) atoms._update_orbitals() return atoms def repeat(self, reps): - """ Repeat this atom object """ + """Repeat this atom object""" atoms = self.copy() atoms._specie = np.repeat(atoms._specie, reps) atoms._update_orbitals() return atoms def swap(self, a, b): - """ Swaps all atoms """ + """Swaps all atoms""" a = _a.asarrayi(a) b = _a.asarrayi(b) atoms = self.copy() @@ -1750,7 +1791,7 @@ def swap(self, a, b): return atoms def swap_atom(self, a, b): - """ Swap specie index positions """ + """Swap specie index positions""" speciea = self.specie_index(a) specieb = self.specie_index(b) @@ -1758,14 +1799,17 @@ def swap_atom(self, a, b): idx_b = (self._specie == specieb).nonzero()[0] atoms = self.copy() - atoms._atom[speciea], atoms._atom[specieb] = atoms._atom[specieb], atoms._atom[speciea] + atoms._atom[speciea], atoms._atom[specieb] = ( + atoms._atom[specieb], + atoms._atom[speciea], + ) atoms._specie[idx_a] = specieb atoms._specie[idx_b] = speciea atoms._update_orbitals() return atoms def append(self, other): - """ Append `other` to this list of atoms and return the appended version + """Append `other` to this list of atoms and return the appended version Parameters ---------- @@ -1801,7 +1845,7 @@ def prepend(self, other): return other.append(self) def reverse(self, atoms=None): - """ Returns a reversed geometry + """Returns a reversed geometry Also enables reversing a subset of the atoms. """ @@ -1814,7 +1858,7 @@ def reverse(self, atoms=None): return copy def insert(self, index, other): - """ Insert other atoms into the list of atoms at index """ + """Insert other atoms into the list of atoms at index""" if isinstance(other, Atom): other = Atoms(other) else: @@ -1836,21 +1880,21 @@ def insert(self, index, other): return atoms def __str__(self): - """ Return the `Atoms` in str """ + """Return the `Atoms` in str""" s = f"{self.__class__.__name__}{{species: {len(self._atom)},\n" for a, idx in self.iter(True): - s += ' {1}: {0},\n'.format(len(idx), str(a).replace('\n', '\n ')) + s += " {1}: {0},\n".format(len(idx), str(a).replace("\n", "\n ")) return f"{s}}}" def __repr__(self): return f"<{self.__module__}.{self.__class__.__name__} nspecies={len(self._atom)}, na={len(self)}, no={self.no}>" def __len__(self): - """ Return number of atoms in the object """ + """Return number of atoms in the object""" return len(self._specie) def iter(self, species=False): - """ Loop on all atoms + """Loop on all atoms This iterator may be used in two contexts: @@ -1874,15 +1918,15 @@ def iter(self, species=False): yield self._atom[s] def __iter__(self): - """ Loop on all atoms with the same specie in order of atoms """ + """Loop on all atoms with the same specie in order of atoms""" yield from self.iter() def __contains__(self, key): - """ Determine whether the `key` is in the unique atoms list """ + """Determine whether the `key` is in the unique atoms list""" return key in self.atom def __getitem__(self, key): - """ Return an `Atom` object corresponding to the key(s) """ + """Return an `Atom` object corresponding to the key(s)""" if isinstance(key, slice): sl = key.indices(len(self)) return [self.atom[self._specie[s]] for s in range(sl[0], sl[1], sl[2])] @@ -1899,8 +1943,8 @@ def __getitem__(self, key): return [self.atom[i] for i in self._specie[key]] def __setitem__(self, key, value): - """ Overwrite an `Atom` object corresponding to the key(s) """ - #If key is a string, we replace the atom that matches 'key' + """Overwrite an `Atom` object corresponding to the key(s)""" + # If key is a string, we replace the atom that matches 'key' if isinstance(key, str): self.replace_atom(self[key], value) return @@ -1928,7 +1972,7 @@ def __setitem__(self, key, value): self._update_orbitals() def replace(self, index, atom): - """ Replace all atomic indices `index` with the atom `atom` (in-place) + """Replace all atomic indices `index` with the atom `atom` (in-place) This is the preferred way of replacing atoms in geometries. @@ -1944,8 +1988,10 @@ def replace(self, index, atom): self.replace_atom(index, atom) return if not isinstance(atom, Atom): - raise TypeError(f"{self.__class__.__name__}.replace requires input arguments to " - "be of the type Atom") + raise TypeError( + f"{self.__class__.__name__}.replace requires input arguments to " + "be of the type Atom" + ) index = _a.asarrayi(index).ravel() # Be sure to add the atom @@ -1959,15 +2005,17 @@ def replace(self, index, atom): for ius in np.unique(self._specie[index]): a = self._atom[ius] if a.no != atom.no: - a1 = ' ' + str(a).replace('\n', '\n ') - a2 = ' ' + str(atom).replace('\n', '\n ') - info(f'Substituting atom\n{a1}\n->\n{a2}\nwith a different number of orbitals!') + a1 = " " + str(a).replace("\n", "\n ") + a2 = " " + str(atom).replace("\n", "\n ") + info( + f"Substituting atom\n{a1}\n->\n{a2}\nwith a different number of orbitals!" + ) self._specie[index] = specie # Update orbital counts... self._update_orbitals() def replace_atom(self, atom_from, atom_to): - """ Replace all atoms equivalent to `atom_from` with `atom_to` (in-place) + """Replace all atoms equivalent to `atom_from` with `atom_to` (in-place) I.e. this is the preferred way of adapting all atoms of a specific type with another one. @@ -1991,11 +2039,15 @@ def replace_atom(self, atom_from, atom_to): if the atoms does not have the same number of orbitals. """ if not isinstance(atom_from, Atom): - raise TypeError(f"{self.__class__.__name__}.replace_atom requires input arguments to " - "be of the class Atom") + raise TypeError( + f"{self.__class__.__name__}.replace_atom requires input arguments to " + "be of the class Atom" + ) if not isinstance(atom_to, Atom): - raise TypeError(f"{self.__class__.__name__}.replace_atom requires input arguments to " - "be of the class Atom") + raise TypeError( + f"{self.__class__.__name__}.replace_atom requires input arguments to " + "be of the class Atom" + ) # Get index of `atom_from` idx_from = self.specie_index(atom_from) @@ -2016,15 +2068,17 @@ def replace_atom(self, atom_from, atom_to): self._atom[idx_from] = atom_to if atom_from.no != atom_to.no: - a1 = ' ' + str(atom_from).replace('\n', '\n ') - a2 = ' ' + str(atom_to).replace('\n', '\n ') - info(f'Replacing atom\n{a1}\n->\n{a2}\nwith a different number of orbitals!') + a1 = " " + str(atom_from).replace("\n", "\n ") + a2 = " " + str(atom_to).replace("\n", "\n ") + info( + f"Replacing atom\n{a1}\n->\n{a2}\nwith a different number of orbitals!" + ) # Update orbital counts... self._update_orbitals() def hassame(self, other, R=True): - """ True if the contained atoms are the same in the two lists + """True if the contained atoms are the same in the two lists Notes ----- @@ -2055,7 +2109,7 @@ def hassame(self, other, R=True): return True def equal(self, other, R=True): - """ True if the contained atoms are the same in the two lists (also checks indices) + """True if the contained atoms are the same in the two lists (also checks indices) Parameters ---------- @@ -2078,8 +2132,10 @@ def equal(self, other, R=True): if is_in == -1: return False # We should check that they also have the same indices - if not np.all(np.nonzero(self.specie == iA)[0] \ - == np.nonzero(other.specie == is_in)[0]): + if not np.all( + np.nonzero(self.specie == iA)[0] + == np.nonzero(other.specie == is_in)[0] + ): return False else: for iB, B in enumerate(other.atom): @@ -2091,23 +2147,24 @@ def equal(self, other, R=True): if is_in == -1: return False # We should check that they also have the same indices - if not np.all(np.nonzero(other.specie == iB)[0] \ - == np.nonzero(self.specie == is_in)[0]): + if not np.all( + np.nonzero(other.specie == iB)[0] + == np.nonzero(self.specie == is_in)[0] + ): return False return True def __eq__(self, b): - """ Returns true if the contained atoms are the same """ + """Returns true if the contained atoms are the same""" return self.equal(b) # Create pickling routines def __getstate__(self): - """ Return the state of this object """ - return {'atom': self.atom, - 'specie': self.specie} + """Return the state of this object""" + return {"atom": self.atom, "specie": self.specie} def __setstate__(self, d): - """ Re-create the state of this object """ + """Re-create the state of this object""" self.__init__() - self._atom = d['atom'] - self._specie = d['specie'] + self._atom = d["atom"] + self._specie = d["specie"] diff --git a/src/sisl/conftest.py b/src/sisl/conftest.py index dd65537e34..59f7a081fd 100644 --- a/src/sisl/conftest.py +++ b/src/sisl/conftest.py @@ -13,29 +13,30 @@ # Here we create the necessary methods and fixtures to enabled/disable # tests depending on whether a sisl-files directory is present. + # Modify items based on whether the env is correct or not def pytest_collection_modifyitems(config, items): sisl_files_tests = _environ.get_environ_variable("SISL_FILES_TESTS") if sisl_files_tests.is_dir(): - if (sisl_files_tests / 'sisl').is_dir(): + if (sisl_files_tests / "sisl").is_dir(): return - print(f'pytest-sisl: Could not locate sisl directory in: {sisl_files_tests}') + print(f"pytest-sisl: Could not locate sisl directory in: {sisl_files_tests}") return xfail_sisl_files = pytest.mark.xfail( - run=False, - reason="requires env(SISL_FILES_TESTS) pointing to clone of: https://github.com/zerothi/sisl-files" + run=False, + reason="requires env(SISL_FILES_TESTS) pointing to clone of: https://github.com/zerothi/sisl-files", ) for item in items: # Only skip those that have the sisl_files fixture # GLOBAL skipping of ALL tests that don't have this fixture - if 'sisl_files' in item.fixturenames: + if "sisl_files" in item.fixturenames: item.add_marker(xfail_sisl_files) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def sisl_tmp(request, tmp_path_factory): - """ sisl specific temporary file and directory creator. + """sisl specific temporary file and directory creator. sisl_tmp(file, dir_name='sisl') sisl_tmp.file(file, dir_name='sisl') @@ -44,24 +45,25 @@ def sisl_tmp(request, tmp_path_factory): The scope of the `sisl_tmp` fixture is at a function level to clean up after each function. """ + class FileFactory: def __init__(self): self.base = tmp_path_factory.getbasetemp() self.dirs = [self.base] self.files = [] - def dir(self, name='sisl'): + def dir(self, name="sisl"): # Make name a path - D = Path(name.replace(os.path.sep, '-')) + D = Path(name.replace(os.path.sep, "-")) if not (self.base / D).is_dir(): # tmp_path_factory.mktemp returns pathlib.Path self.dirs.append(tmp_path_factory.mktemp(str(D), numbered=False)) return self.dirs[-1] - def file(self, name, dir_name='sisl'): + def file(self, name, dir_name="sisl"): # self.base *is* a pathlib - D = self.base / dir_name.replace(os.path.sep, '-') + D = self.base / dir_name.replace(os.path.sep, "-") if D in self.dirs: i = self.dirs.index(D) else: @@ -73,8 +75,8 @@ def file(self, name, dir_name='sisl'): def getbase(self): return self.dirs[-1] - def __call__(self, name, dir_name='sisl'): - """ Shorthand for self.file """ + def __call__(self, name, dir_name="sisl"): + """Shorthand for self.file""" return self.file(name, dir_name) def teardown(self): @@ -98,14 +100,15 @@ def teardown(self): d.rmdir() except Exception: pass + ff = FileFactory() request.addfinalizer(ff.teardown) return ff -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def sisl_files(): - """ Environment catcher for the large files hosted in a different repository. + """Environment catcher for the large files hosted in a different repository. If SISL_FILES_TESTS has been defined in the environment variable the directory will be used for the tests with this as a fixture. @@ -115,10 +118,13 @@ def sisl_files(): """ sisl_files_tests = _environ.get_environ_variable("SISL_FILES_TESTS") if not sisl_files_tests.is_dir(): + def _path(*files): pytest.xfail( - reason=f"Environment SISL_FILES_TESTS not pointing to a valid directory.", - run=False) + reason=f"Environment SISL_FILES_TESTS not pointing to a valid directory.", + run=False, + ) + return _path def _path(*files): @@ -127,42 +133,51 @@ def _path(*files): return p # I expect this test to fail due to the wrong environment. # But it isn't an actual fail since it hasn't runned... - pytest.xfail(reason=f"Environment SISL_FILES_TESTS may point to a wrong path(?); file {p} not found", - run=False) + pytest.xfail( + reason=f"Environment SISL_FILES_TESTS may point to a wrong path(?); file {p} not found", + run=False, + ) + return _path -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def sisl_system(): - """ A preset list of geometries/Hamiltonians. """ + """A preset list of geometries/Hamiltonians.""" + class System: pass d = System() alat = 1.42 - sq3h = 3.**.5 * 0.5 + sq3h = 3.0**0.5 * 0.5 C = Atom(Z=6, R=1.42) - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * alat, - nsc=[3, 3, 1]) - d.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * alat, - atoms=C, lattice=lattice) + lattice = Lattice( + np.array([[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64) + * alat, + nsc=[3, 3, 1], + ) + d.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * alat, + atoms=C, + lattice=lattice, + ) d.R = np.array([0.1, 1.5]) - d.t = np.array([0., 2.7]) - d.tS = np.array([(0., 1.0), - (2.7, 0.)]) + d.t = np.array([0.0, 2.7]) + d.tS = np.array([(0.0, 1.0), (2.7, 0.0)]) d.C = Atom(Z=6, R=max(d.R)) - d.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * alat, - nsc=[3, 3, 1]) - d.gtb = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * alat, - atoms=C, lattice=lattice) + d.lattice = Lattice( + np.array([[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64) + * alat, + nsc=[3, 3, 1], + ) + d.gtb = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * alat, + atoms=C, + lattice=lattice, + ) d.ham = Hamiltonian(d.gtb) d.ham.construct([(0.1, 1.5), (0.1, 2.7)]) @@ -188,7 +203,7 @@ def pytest_ignore_collect(path, config): global _skip_paths parts = list(Path(path).parts) parts.reverse() - sisl_parts = parts[:parts.index("sisl")] + sisl_parts = parts[: parts.index("sisl")] sisl_parts.reverse() sisl_path = str(Path("sisl").joinpath(*sisl_parts)) @@ -199,30 +214,86 @@ def pytest_ignore_collect(path, config): def pytest_configure(config): - pytest.sisl_travis_skip = pytest.mark.skipif( os.environ.get("SISL_TRAVIS_CI", "false").lower() == "true", - reason="running on TRAVIS" + reason="running on TRAVIS", ) # Locally manage pytest.ini input - for mark in ['io', 'generic', 'bloch', 'hamiltonian', 'geometry', 'geom', 'shape', - 'state', 'electron', 'phonon', 'utils', 'unit', 'distribution', - 'spin', 'self_energy', 'help', 'messages', 'namedindex', 'sparse', - 'lattice', 'supercell', 'sc', 'quaternion', 'sparse_geometry', 'sparse_orbital', - 'ranges', 'physics', "physics_feature", - 'orbital', 'oplist', 'grid', 'atoms', 'atom', 'sgrid', 'sdata', 'sgeom', - 'version', - 'bz', 'brillouinzone', "monkhorstpack", "bandstructure", - 'inv', 'eig', 'linalg', - 'density_matrix', 'dynamicalmatrix', 'energydensity_matrix', - 'siesta', 'tbtrans', 'vasp', 'w90', 'wannier90', 'gulp', 'fdf', - "fhiaims", "aims", "orca", - "collection", - "category", "geom_category", "plot", - 'slow', 'selector', 'overlap', 'mixing', - "typing", "only", - 'viz', "processors", "data", "plots", "plotters"]: + for mark in [ + "io", + "generic", + "bloch", + "hamiltonian", + "geometry", + "geom", + "shape", + "state", + "electron", + "phonon", + "utils", + "unit", + "distribution", + "spin", + "self_energy", + "help", + "messages", + "namedindex", + "sparse", + "lattice", + "supercell", + "sc", + "quaternion", + "sparse_geometry", + "sparse_orbital", + "ranges", + "physics", + "physics_feature", + "orbital", + "oplist", + "grid", + "atoms", + "atom", + "sgrid", + "sdata", + "sgeom", + "version", + "bz", + "brillouinzone", + "monkhorstpack", + "bandstructure", + "inv", + "eig", + "linalg", + "density_matrix", + "dynamicalmatrix", + "energydensity_matrix", + "siesta", + "tbtrans", + "vasp", + "w90", + "wannier90", + "gulp", + "fdf", + "fhiaims", + "aims", + "orca", + "collection", + "category", + "geom_category", + "plot", + "slow", + "selector", + "overlap", + "mixing", + "typing", + "only", + "viz", + "processors", + "data", + "plots", + "plotters", + ]: config.addinivalue_line( "markers", f"{mark}: mark test to run only on named environment" ) diff --git a/src/sisl/constant.py b/src/sisl/constant.py index 58af33eb56..0216fe1820 100644 --- a/src/sisl/constant.py +++ b/src/sisl/constant.py @@ -36,12 +36,12 @@ from ._internal import set_module from .unit.base import units -__all__ = ['PhysicalConstant'] +__all__ = ["PhysicalConstant"] @set_module("sisl") class PhysicalConstant(float): - """ Class to create a physical constant with unit-conversion capability, works exactly like a float. + """Class to create a physical constant with unit-conversion capability, works exactly like a float. To change the units simply call it like a method with the desired unit: @@ -56,7 +56,8 @@ class PhysicalConstant(float): >>> m2nm * 2 1000000000.0 """ - __slots__ = ['_unit'] + + __slots__ = ["_unit"] def __new__(cls, value, unit): constant = float.__new__(cls, value) @@ -67,14 +68,14 @@ def __init__(self, value, unit): @property def unit(self): - """ Unit of constant """ + """Unit of constant""" return self._unit def __str__(self): - return '{} {}'.format(float(self), self.unit) + return "{} {}".format(float(self), self.unit) def __call__(self, unit=None): - """ Return the value for the constant in the given unit, otherwise will return the units in SI units """ + """Return the value for the constant in the given unit, otherwise will return the units in SI units""" if unit is None: return self return PhysicalConstant(self * units(self.unit, unit), unit) @@ -85,7 +86,7 @@ def __eq__(self, other): return super().__eq__(self, other) -__all__ += ['q', 'c', 'h', 'hbar', 'm_e', 'm_p', 'G', 'G0', 'a0'] +__all__ += ["q", "c", "h", "hbar", "m_e", "m_p", "G", "G0", "a0"] # These are CODATA-2018 values @@ -109,7 +110,6 @@ def __eq__(self, other): # Values not found in the CODATA table #: Conductance quantum [S], or [m^2/s^2] -G0 = PhysicalConstant(2 * (q ** 2 / h), 'm^2/s^2') +G0 = PhysicalConstant(2 * (q**2 / h), "m^2/s^2") #: Gravitational constant [m^3/kg/s^2] -G = PhysicalConstant(6.6740831e-11, 'm^3/kg/s^2') - +G = PhysicalConstant(6.6740831e-11, "m^3/kg/s^2") diff --git a/src/sisl/geom/_common.py b/src/sisl/geom/_common.py index 832ff9d95e..f96414bde1 100644 --- a/src/sisl/geom/_common.py +++ b/src/sisl/geom/_common.py @@ -7,8 +7,8 @@ def geometry_define_nsc(geometry, periodic=(True, True, True)): - """Define the number of supercells for a geometry based on the periodicity """ - if np.all(geometry.maxR(True) > 0.): + """Define the number of supercells for a geometry based on the periodicity""" + if np.all(geometry.maxR(True) > 0.0): # strip away optimizing nsc for the non-periodic directions axes = [i for i, p in enumerate(periodic) if p] @@ -29,7 +29,7 @@ def geometry_define_nsc(geometry, periodic=(True, True, True)): def geometry2uc(geometry, dx=1e-8): - """ Translate the geometry to the unit cell by first shifting `dx` """ + """Translate the geometry to the unit cell by first shifting `dx`""" geometry = geometry.move(dx).translate2uc().move(-dx) geometry.xyz[geometry.xyz < 0] = 0 return geometry diff --git a/src/sisl/geom/_composite.py b/src/sisl/geom/_composite.py index efe2585c1c..dd0826d28b 100644 --- a/src/sisl/geom/_composite.py +++ b/src/sisl/geom/_composite.py @@ -8,7 +8,6 @@ @dataclass class CompositeGeometrySection: - @abstractmethod def build_section(self, geometry): ... @@ -47,18 +46,19 @@ def composite_geometry(sections, section_cls, **kwargs): **kwargs: Keyword arguments used as defaults for the sections when the . """ + # Parse sections into Section objects def conv(s): # If it is some arbitrary type, convert it to a tuple if not isinstance(s, (section_cls, tuple, dict)): - s = (s, ) + s = (s,) # If we arrived here with a tuple, convert it to a dict if isinstance(s, tuple): s = {field.name: val for field, val in zip(fields(section_cls), s)} # At this point it is either a dict or already a section object. if isinstance(s, dict): return section_cls(**{**kwargs, **s}) - + return copy.copy(s) # Then loop through all the sections. @@ -68,14 +68,14 @@ def conv(s): section = conv(section) new_addition = section.build_section(prev) - + if i == 0: geom = new_addition else: geom = section.add_section(geom, new_addition) - + prev = section - + return geom diff --git a/src/sisl/geom/basic.py b/src/sisl/geom/basic.py index c840b9e673..127d01e546 100644 --- a/src/sisl/geom/basic.py +++ b/src/sisl/geom/basic.py @@ -8,24 +8,23 @@ from ._common import geometry2uc, geometry_define_nsc -__all__ = ['sc', 'bcc', 'fcc', 'hcp', 'rocksalt'] +__all__ = ["sc", "bcc", "fcc", "hcp", "rocksalt"] # A few needed variables _s30 = 1 / 2 -_s60 = 3 ** .5 / 2 -_s45 = 1 / 2 ** .5 +_s60 = 3**0.5 / 2 +_s45 = 1 / 2**0.5 _c30 = _s60 _c60 = _s30 _c45 = _s45 -_t30 = 1 / 3 ** .5 -_t45 = 1. -_t60 = 3 ** .5 +_t30 = 1 / 3**0.5 +_t45 = 1.0 +_t60 = 3**0.5 @set_module("sisl.geom") -def sc(alat: float, - atom): - """ Simple cubic lattice with 1 atom +def sc(alat: float, atom): + """Simple cubic lattice with 1 atom Parameters ---------- @@ -34,19 +33,15 @@ def sc(alat: float, atom : Atom the atom in the SC lattice """ - lattice = Lattice(np.array([[1, 0, 0], - [0, 1, 0], - [0, 0, 1]], np.float64) * alat) + lattice = Lattice(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], np.float64) * alat) g = Geometry([0, 0, 0], atom, lattice=lattice) geometry_define_nsc(g) return g @set_module("sisl.geom") -def bcc(alat: float, - atoms, - orthogonal: bool=False): - """ Body centered cubic lattice with 1 (non-orthogonal) or 2 atoms (orthogonal) +def bcc(alat: float, atoms, orthogonal: bool = False): + """Body centered cubic lattice with 1 (non-orthogonal) or 2 atoms (orthogonal) Parameters ---------- @@ -58,25 +53,23 @@ def bcc(alat: float, whether the lattice is orthogonal (2 atoms) """ if orthogonal: - lattice = Lattice(np.array([[1, 0, 0], - [0, 1, 0], - [0, 0, 1]], np.float64) * alat) + lattice = Lattice( + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], np.float64) * alat + ) ah = alat / 2 g = Geometry([[0, 0, 0], [ah, ah, ah]], atoms, lattice=lattice) else: - lattice = Lattice(np.array([[-1, 1, 1], - [1, -1, 1], - [1, 1, -1]], np.float64) * alat / 2) + lattice = Lattice( + np.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]], np.float64) * alat / 2 + ) g = Geometry([0, 0, 0], atoms, lattice=lattice) geometry_define_nsc(g) return g @set_module("sisl.geom") -def fcc(alat: float, - atoms, - orthogonal: bool=False): - """ Face centered cubic lattice with 1 (non-orthogonal) or 4 atoms (orthogonal) +def fcc(alat: float, atoms, orthogonal: bool = False): + """Face centered cubic lattice with 1 (non-orthogonal) or 4 atoms (orthogonal) Parameters ---------- @@ -88,27 +81,25 @@ def fcc(alat: float, whether the lattice is orthogonal (4 atoms) """ if orthogonal: - lattice = Lattice(np.array([[1, 0, 0], - [0, 1, 0], - [0, 0, 1]], np.float64) * alat) + lattice = Lattice( + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], np.float64) * alat + ) ah = alat / 2 - g = Geometry([[0, 0, 0], [ah, ah, 0], - [ah, 0, ah], [0, ah, ah]], atoms, lattice=lattice) + g = Geometry( + [[0, 0, 0], [ah, ah, 0], [ah, 0, ah], [0, ah, ah]], atoms, lattice=lattice + ) else: - lattice = Lattice(np.array([[0, 1, 1], - [1, 0, 1], - [1, 1, 0]], np.float64) * alat / 2) + lattice = Lattice( + np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], np.float64) * alat / 2 + ) g = Geometry([0, 0, 0], atoms, lattice=lattice) geometry_define_nsc(g) return g @set_module("sisl.geom") -def hcp(a: float, - atoms, - coa: float=49/30, - orthogonal: bool=False): - """ Hexagonal closed packed lattice with 2 (non-orthogonal) or 4 atoms (orthogonal) +def hcp(a: float, atoms, coa: float = 49 / 30, orthogonal: bool = False): + """Hexagonal closed packed lattice with 2 (non-orthogonal) or 4 atoms (orthogonal) Parameters ---------- @@ -123,15 +114,21 @@ def hcp(a: float, """ # height of hcp structure c = a * coa - a3sq = a / 3 ** .5 + a3sq = a / 3**0.5 if orthogonal: - lattice = Lattice([[a + a * _c60 * 2, 0, 0], - [0, a * _c30 * 2, 0], - [0, 0, c / 2]]) - gt = Geometry([[0, 0, 0], - [a, 0, 0], - [a * _s30, a * _c30, 0], - [a * (1 + _s30), a * _c30, 0]], atoms, lattice=lattice) + lattice = Lattice( + [[a + a * _c60 * 2, 0, 0], [0, a * _c30 * 2, 0], [0, 0, c / 2]] + ) + gt = Geometry( + [ + [0, 0, 0], + [a, 0, 0], + [a * _s30, a * _c30, 0], + [a * (1 + _s30), a * _c30, 0], + ], + atoms, + lattice=lattice, + ) # Create the rotated one on top gr = gt.copy() # mirror structure @@ -143,17 +140,16 @@ def hcp(a: float, g = gt.append(gr, 2) else: lattice = Lattice([a, a, c, 90, 90, 60]) - g = Geometry([[0, 0, 0], [a3sq * _c30, a3sq * _s30, c / 2]], - atoms, lattice=lattice) + g = Geometry( + [[0, 0, 0], [a3sq * _c30, a3sq * _s30, c / 2]], atoms, lattice=lattice + ) geometry_define_nsc(g) return g @set_module("sisl.geom") -def rocksalt(alat: float, - atoms, - orthogonal: bool=False): - """ Two-element rocksalt lattice with 2 (non-orthogonal) or 8 atoms (orthogonal) +def rocksalt(alat: float, atoms, orthogonal: bool = False): + """Two-element rocksalt lattice with 2 (non-orthogonal) or 8 atoms (orthogonal) This is equivalent to the NaCl crystal structure (halite). @@ -163,7 +159,7 @@ def rocksalt(alat: float, lattice parameter atoms : Atom a list of two atoms that the crystal consists of - orthogonal: + orthogonal: whether the lattice is orthogonal or not """ if isinstance(atoms, str): diff --git a/src/sisl/geom/bilayer.py b/src/sisl/geom/bilayer.py index 77f9256094..4768f87b15 100644 --- a/src/sisl/geom/bilayer.py +++ b/src/sisl/geom/bilayer.py @@ -12,19 +12,21 @@ from ._common import geometry_define_nsc -__all__ = ['bilayer'] +__all__ = ["bilayer"] @set_module("sisl.geom") -def bilayer(bond: float=1.42, - bottom_atoms: Optional=None, - top_atoms: Optional=None, - stacking: str='AB', - twist=(0, 0), - separation: float=3.35, - ret_angle: bool=False, - layer: str='both'): - r""" Commensurate unit cell of a hexagonal bilayer structure, possibly with a twist angle. +def bilayer( + bond: float = 1.42, + bottom_atoms: Optional = None, + top_atoms: Optional = None, + stacking: str = "AB", + twist=(0, 0), + separation: float = 3.35, + ret_angle: bool = False, + layer: str = "both", +): + r"""Commensurate unit cell of a hexagonal bilayer structure, possibly with a twist angle. This routine follows the prescription of twisted bilayer graphene found in :cite:`Trambly2010`. @@ -69,11 +71,11 @@ def bilayer(bond: float=1.42, ref_cell = bottom.cell.copy() stacking = stacking.lower() - if stacking == 'aa': + if stacking == "aa": top = top.move([0, 0, separation]) - elif stacking == 'ab': + elif stacking == "ab": top = top.move([-bond, 0, separation]) - elif stacking == 'ba': + elif stacking == "ba": top = top.move([bond, 0, separation]) else: raise ValueError("bilayer: stacking must be one of {AA, AB, BA}") @@ -95,37 +97,35 @@ def bilayer(bond: float=1.42, natoms = 2 else: # Twisting - cos_theta = (n ** 2 + 4 * n * m + m ** 2) / (2 * (n ** 2 + n * m + m ** 2)) + cos_theta = (n**2 + 4 * n * m + m**2) / (2 * (n**2 + n * m + m**2)) theta = acos(cos_theta) * 180 / pi rep = 4 * (n + m) - natoms = 2 * (n ** 2 + n * m + m ** 2) + natoms = 2 * (n**2 + n * m + m**2) if rep > 1: # Set origin through an A atom near the middle of the geometry - align_vec = - rep * (ref_cell[0] + ref_cell[1]) / 2 + align_vec = -rep * (ref_cell[0] + ref_cell[1]) / 2 - bottom = (bottom - .tile(rep, axis=0) - .tile(rep, axis=1) - .move(align_vec)) + bottom = bottom.tile(rep, axis=0).tile(rep, axis=1).move(align_vec) # Set new lattice vectors bottom.cell[0] = n * ref_cell[0] + m * ref_cell[1] bottom.cell[1] = -m * ref_cell[0] + (n + m) * ref_cell[1] # Remove atoms outside cell - cell_box = Cuboid(bottom.cell, center=[- bond * 1e-4] * 3) + cell_box = Cuboid(bottom.cell, center=[-bond * 1e-4] * 3) # Reduce atoms in bottom inside_idx = cell_box.within_index(bottom.xyz) bottom = bottom.sub(inside_idx) # Rotate top layer around A atom in bottom layer - top = (top - .tile(rep, axis=0) - .tile(rep, axis=1) - .move(align_vec) - .rotate(theta, [0, 0, 1], what="abc+xyz")) + top = ( + top.tile(rep, axis=0) + .tile(rep, axis=1) + .move(align_vec) + .rotate(theta, [0, 0, 1], what="abc+xyz") + ) inside_idx = cell_box.within_index(top.xyz) top = top.sub(inside_idx) @@ -135,11 +135,11 @@ def bilayer(bond: float=1.42, # Which layers to be returned layer = layer.lower() - if layer == 'bottom': + if layer == "bottom": bilayer = bottom - elif layer == 'top': + elif layer == "top": bilayer = top - elif layer == 'both': + elif layer == "both": bilayer = bottom.add(top) natoms *= 2 else: @@ -148,12 +148,12 @@ def bilayer(bond: float=1.42, if rep > 1: # Rotate and shift unit cell back fxyz_min = bilayer.fxyz.min(axis=0) - fxyz_min[2] = 0. + fxyz_min[2] = 0.0 # This is a small hack since rotate is not numerically # precise. # TODO We could consider using mpfmath in Quaternion for increased # precision... - fxyz_min[np.fabs(fxyz_min) > 1.e-7] *= 0.49 + fxyz_min[np.fabs(fxyz_min) > 1.0e-7] *= 0.49 offset = fxyz_min.dot(bilayer.cell) vec = bilayer.cell[0] + bilayer.cell[1] vec_costh = vec[0] / vec.dot(vec) ** 0.5 diff --git a/src/sisl/geom/category/_coord.py b/src/sisl/geom/category/_coord.py index aea0128149..2ee11dcc8a 100644 --- a/src/sisl/geom/category/_coord.py +++ b/src/sisl/geom/category/_coord.py @@ -23,7 +23,7 @@ @set_module("sisl.geom") class AtomFracSite(AtomCategory): - r""" Classify atoms based on fractional sites for a given supercell + r"""Classify atoms based on fractional sites for a given supercell Match atomic coordinates based on the fractional positions. @@ -54,10 +54,16 @@ class AtomFracSite(AtomCategory): ... else: ... assert c == B_site """ - __slots__ = (f"_{a}" for a in ("cell", "icell", "length", "atol", "offset", "foffset")) - - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") - def __init__(self, lattice, atol=1.e-5, offset=(0., 0., 0.), foffset=(0., 0., 0.)): + __slots__ = ( + f"_{a}" for a in ("cell", "icell", "length", "atol", "offset", "foffset") + ) + + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) + def __init__( + self, lattice, atol=1.0e-5, offset=(0.0, 0.0, 0.0), foffset=(0.0, 0.0, 0.0) + ): if isinstance(lattice, LatticeChild): lattice = lattice.lattice elif not isinstance(lattice, Lattice): @@ -76,27 +82,36 @@ def __init__(self, lattice, atol=1.e-5, offset=(0., 0., 0.), foffset=(0., 0., 0. # fractional offset before comparing to the integer part of the fractional coordinate self._foffset = _a.arrayd(foffset).reshape(1, 3) - super().__init__(f"fracsite(atol={self._atol}, offset={self._offset}, foffset={self._foffset})") + super().__init__( + f"fracsite(atol={self._atol}, offset={self._offset}, foffset={self._foffset})" + ) def categorize(self, geometry, atoms=None): # _sanitize_loop will ensure that atoms will always be an integer if atoms is None: fxyz = np.dot(geometry.xyz + self._offset, self._icell.T) + self._foffset else: - fxyz = np.dot(geometry.xyz[atoms].reshape(-1, 3) + self._offset, - self._icell.T) + self._foffset + fxyz = ( + np.dot(geometry.xyz[atoms].reshape(-1, 3) + self._offset, self._icell.T) + + self._foffset + ) # Find fractional indices that match to an integer of the passed cell # We multiply with the length of the cell to get an error in Ang - ret = np.where(np.fabs((fxyz - np.rint(fxyz))*self._length).max(1) <= self._atol, - self, NullCategory()).tolist() + ret = np.where( + np.fabs((fxyz - np.rint(fxyz)) * self._length).max(1) <= self._atol, + self, + NullCategory(), + ).tolist() if isinstance(atoms, Integral): ret = ret[0] return ret def __eq__(self, other): if self.__class__ is other.__class__: - for s, o in map(lambda a: (getattr(self, f"_{a}"), getattr(other, f"_{a}")), - ("cell", "icell", "atol", "offset", "foffset")): + for s, o in map( + lambda a: (getattr(self, f"_{a}"), getattr(other, f"_{a}")), + ("cell", "icell", "atol", "offset", "foffset"), + ): if not np.allclose(s, o): return False return True @@ -105,17 +120,17 @@ def __eq__(self, other): @set_module("sisl.geom") class AtomXYZ(AtomCategory): - r""" Classify atoms based on coordinates + r"""Classify atoms based on coordinates Parameters ---------- *args : Shape any shape that implements `Shape.within` - **kwargs: + **kwargs: keys are operator specifications and values are used in those specifications. The keys are split into 3 sections - ``__`` + ``__`` - ``options`` are made of combinations of ``['a', 'f']`` i.e. ``"af"``, ``"f"`` or ``"a"`` are all valid. An ``a`` takes the absolute value, ``f`` means a fractional @@ -147,27 +162,34 @@ class AtomXYZ(AtomCategory): __slots__ = ("_coord_check",) def __init__(self, *args, **kwargs): - def create1(is_frac, is_abs, op, d): if is_abs: + @wraps(op) def func(a, b): return op(np.fabs(a)) + else: + @wraps(op) def func(a, b): return op(a) + return is_frac, func, d, None def create2(is_frac, is_abs, op, d, val): if is_abs: + @wraps(op) def func(a, b): return op(np.fabs(a), b) + else: + @wraps(op) def func(a, b): return op(a, b) + return is_frac, func, d, val coord_ops = [] @@ -175,8 +197,10 @@ def func(a, b): # For each *args we expect this to be a shape for arg in args: if not isinstance(arg, Shape): - raise ValueError(f"{self.__class__.__name__} requires non-keyword arguments " - f"to be of type Shape {type(arg)}.") + raise ValueError( + f"{self.__class__.__name__} requires non-keyword arguments " + f"to be of type Shape {type(arg)}." + ) coord_ops.append(create1(False, False, arg.within, (0, 1, 2))) @@ -198,8 +222,10 @@ def func(a, b): elif value.size == 2: sdir = key else: - raise ValueError(f"{self.__class__.__name__} could not determine the operations for {key}={value}.\n" - f"{key} must be on the form [fa]__") + raise ValueError( + f"{self.__class__.__name__} could not determine the operations for {key}={value}.\n" + f"{key} must be on the form [fa]__" + ) # parse options is_abs = "a" in spec @@ -211,11 +237,17 @@ def func(a, b): if value.size == 2: # do it twice if not value[0] is None: - coord_ops.append(create2(is_frac, is_abs, operator.ge, sdir, value[0])) + coord_ops.append( + create2(is_frac, is_abs, operator.ge, sdir, value[0]) + ) if not value[1] is None: - coord_ops.append(create2(is_frac, is_abs, operator.le, sdir, value[1])) + coord_ops.append( + create2(is_frac, is_abs, operator.le, sdir, value[1]) + ) else: - coord_ops.append(create2(is_frac, is_abs, getattr(operator, op), sdir, value)) + coord_ops.append( + create2(is_frac, is_abs, getattr(operator, op), sdir, value) + ) self._coord_check = coord_ops super().__init__("coord") @@ -227,14 +259,18 @@ def categorize(self, geometry, atoms=None): else: xyz = geometry.xyz[atoms] fxyz = geometry.fxyz[atoms] + def call(frac, func, d, val): if frac: return func(fxyz[..., d], val) return func(xyz[..., d], val) and_reduce = np.logical_and.reduce - ret = np.where(and_reduce([call(*four) for four in self._coord_check]), - self, NullCategory()).tolist() + ret = np.where( + and_reduce([call(*four) for four in self._coord_check]), + self, + NullCategory(), + ).tolist() if isinstance(atoms, Integral): ret = ret[0] return ret @@ -286,8 +322,10 @@ def _new(cls, *interval, **kwargs): """ Will go into the __new__ method of the coordinate classes """ + def _apply_key(k, v): return f"{key}_{k}", v + if len(kwargs) > 0: new_kwargs = dict(map(_apply_key, *zip(*kwargs.items()))) else: @@ -303,21 +341,25 @@ def _apply_key(k, v): # error for multiple entries return AtomXYZ(**{key: interval}, **new_kwargs) elif len(interval) != 0: - raise ValueError(f"{cls.__name__} non-keyword argumest must be 1 tuple, or 2 values") + raise ValueError( + f"{cls.__name__} non-keyword argumest must be 1 tuple, or 2 values" + ) return AtomXYZ(**new_kwargs) return _new + # Iterate over all directions that deserve an individual class for key in ("x", "y", "z", "f_x", "f_y", "f_z", "a_x", "a_y", "a_z"): - # The underscores are just kept for the key that is passed to AtomXYZ, # but the name of the category will have no underscore name = key.replace("_", "") # Create the class for this direction # note this is lower case since AtomZ should not interfere with Atomz - new_cls = AtomXYZMeta(f"Atom{name}", (AtomCategory, ), {"__new__": _new_factory(key)}) + new_cls = AtomXYZMeta( + f"Atom{name}", (AtomCategory,), {"__new__": _new_factory(key)} + ) new_cls = set_module("sisl.geom")(new_cls) diff --git a/src/sisl/geom/category/_kind.py b/src/sisl/geom/category/_kind.py index 40afa95b63..b04d198a2b 100644 --- a/src/sisl/geom/category/_kind.py +++ b/src/sisl/geom/category/_kind.py @@ -19,7 +19,7 @@ @set_module("sisl.geom") class AtomZ(AtomCategory): - r""" Classify atoms based on atomic number + r"""Classify atoms based on atomic number Parameters ---------- @@ -55,7 +55,7 @@ def __eq__(self, other): @set_module("sisl.geom") class AtomTag(AtomCategory): - r""" Classify atoms based on their tag. + r"""Classify atoms based on their tag. Parameters ---------- @@ -88,7 +88,7 @@ def __eq__(self, other): @set_module("sisl.geom") class AtomIndex(AtomCategory): - r""" Classify atoms based on indices + r"""Classify atoms based on indices Parameters ---------- @@ -134,20 +134,23 @@ def __init__(self, *args, **kwargs): def wrap_func(func): @wraps(func) def make_partial(a, b): - """ Wrapper to make partial useful """ + """Wrapper to make partial useful""" if isinstance(b, Integral): return op.truth(func(a, b)) is_true = True for ib in b: is_true = is_true and func(a, ib) return is_true + return make_partial operator = [] if len(idx) > 0: + @wraps(op.contains) def func_wrap(a, b): return op.contains(b, a) + operator.append((func_wrap, idx)) for func, value in kwargs.items(): @@ -166,7 +169,9 @@ def func_wrap(a, b): # Create string self._op_val = operator - super().__init__(" & ".join(map(lambda f, b: f"{f.__name__}[{b}]", *zip(*self._op_val)))) + super().__init__( + " & ".join(map(lambda f, b: f"{f.__name__}[{b}]", *zip(*self._op_val))) + ) @_sanitize_loop def categorize(self, geometry, atoms=None): @@ -179,13 +184,15 @@ def __eq__(self, other): if self.__class__ is other.__class__: if len(self._op_val) == len(other._op_val): # Check they are the same - return reduce(op.and_, (op_val in other._op_val for op_val in self._op_val), True) + return reduce( + op.and_, (op_val in other._op_val for op_val in self._op_val), True + ) return False @set_module("sisl.geom") class AtomSeq(AtomIndex): - r""" Classify atoms based on their indices using a sequence string. + r"""Classify atoms based on their indices using a sequence string. Parameters ---------- @@ -229,6 +236,7 @@ def _sanitize_negs(indices_map, end): end: int The largest valid index. """ + def _sanitize(item): if isinstance(item, int): if item < 0: @@ -263,7 +271,7 @@ def __eq__(self, other): class AtomEven(AtomCategory): - r""" Classify atoms based on indices (even in this case)""" + r"""Classify atoms based on indices (even in this case)""" __slots__ = [] def __init__(self): @@ -282,7 +290,7 @@ def __eq__(self, other): @set_module("sisl.geom") class AtomOdd(AtomCategory): - r""" Classify atoms based on indices (odd in this case)""" + r"""Classify atoms based on indices (odd in this case)""" __slots__ = [] def __init__(self): diff --git a/src/sisl/geom/category/_neighbours.py b/src/sisl/geom/category/_neighbours.py index f11ce1f77e..26d28f9e47 100644 --- a/src/sisl/geom/category/_neighbours.py +++ b/src/sisl/geom/category/_neighbours.py @@ -12,7 +12,7 @@ @set_module("sisl.geom") class AtomNeighbours(AtomCategory): - r""" Classify atoms based on number of neighbours + r"""Classify atoms based on number of neighbours Parameters ---------- @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): *args, kwargs["neighbour"] = args self._min = 0 - self._max = 2 ** 31 + self._max = 2**31 if len(args) == 1: self._min = args[0] @@ -63,7 +63,7 @@ def __init__(self, *args, **kwargs): if self._min == self._max: name = f"={self._max}" - elif self._max == 2 ** 31: + elif self._max == 2**31: name = f" ∈ [{self._min};∞[" else: name = f" ∈ [{self._min};{self._max}]" @@ -91,7 +91,7 @@ def R(self, atom): @_sanitize_loop def categorize(self, geometry, atoms=None): - """ Check if geometry and atoms matches the neighbour criteria """ + """Check if geometry and atoms matches the neighbour criteria""" idx = geometry.close(atoms, R=self.R(geometry.atoms[atoms]))[1] if len(idx) < self._min: return NullCategory() @@ -100,8 +100,7 @@ def categorize(self, geometry, atoms=None): if not self._in is None: # Get category of neighbours cat = self._in.categorize(geometry, geometry.asc2uc(idx)) - idx = [i for i, c in zip(idx, cat) - if not isinstance(c, NullCategory)] + idx = [i for i, c in zip(idx, cat) if not isinstance(c, NullCategory)] n = len(idx) if self._min <= n and n <= self._max: return self @@ -110,7 +109,9 @@ def categorize(self, geometry, atoms=None): def __eq__(self, other): eq = self.__class__ is other.__class__ if eq: - return (self._min == other._min and - self._max == other._max and - self._in == other._in) + return ( + self._min == other._min + and self._max == other._max + and self._in == other._in + ) return False diff --git a/src/sisl/geom/category/base.py b/src/sisl/geom/category/base.py index 6f5604de63..efa5e4e82a 100644 --- a/src/sisl/geom/category/base.py +++ b/src/sisl/geom/category/base.py @@ -19,8 +19,10 @@ def loop_func(self, geometry, atoms=None): if atoms.ndim == 0: return func(self, geometry, atoms) return [func(self, geometry, ia) for ia in atoms] + return loop_func -#class AtomCategory(Category) + +# class AtomCategory(Category) # is defined in sisl/geometry.py since it is required in # that instance. diff --git a/src/sisl/geom/category/tests/test_geom_category.py b/src/sisl/geom/category/tests/test_geom_category.py index a940d82298..4c9214a2c2 100644 --- a/src/sisl/geom/category/tests/test_geom_category.py +++ b/src/sisl/geom/category/tests/test_geom_category.py @@ -133,8 +133,13 @@ def test_geom_category_seq(): def test_geom_category_tag(): - atoms = [Atom(Z=6, tag="C1"), Atom(Z=6, tag="C2"), Atom(Z=6, tag="C3"), Atom(Z=1, tag="H")] - geom = Geometry([[0, 0, 0]]*4, atoms=atoms) + atoms = [ + Atom(Z=6, tag="C1"), + Atom(Z=6, tag="C2"), + Atom(Z=6, tag="C3"), + Atom(Z=1, tag="H"), + ] + geom = Geometry([[0, 0, 0]] * 4, atoms=atoms) cat = AtomTag("") assert len(geom.asc2uc(cat)) == 4 @@ -186,7 +191,7 @@ def test_geom_category_frac_A_B_site(): gr = graphene() * (4, 5, 1) A_site = AtomFracSite(graphene()) - B_site = AtomFracSite(graphene(), foffset=(-1/3, -1/3, 0)) + B_site = AtomFracSite(graphene(), foffset=(-1 / 3, -1 / 3, 0)) cat = (A_site | B_site).categorize(gr) for i, c in enumerate(cat): @@ -251,7 +256,6 @@ def test_geom_category_xyz_meta(): # Check that all classes work for key in ("x", "y", "z", "f_x", "f_y", "f_z", "a_x", "a_y", "a_z"): - name = key.replace("_", "") # Check that the attribute is present @@ -274,5 +278,7 @@ def get_cls(op, v): assert np.all(sc2uc(cls >= 0.5) == sc2uc(cls(ge=0.5))) assert np.all(sc2uc(cls >= 0.5) == sc2uc(get_cls("ge", 0.5))) - assert np.all(sc2uc(cls((-1, 1))) == sc2uc(get_cls("ge", -1) & get_cls("le", 1))) + assert np.all( + sc2uc(cls((-1, 1))) == sc2uc(get_cls("ge", -1) & get_cls("le", 1)) + ) assert np.all(sc2uc(cls(-1, 1)) == sc2uc(get_cls("ge", -1) & get_cls("le", 1))) diff --git a/src/sisl/geom/flat.py b/src/sisl/geom/flat.py index 44bc91aef7..2610125048 100644 --- a/src/sisl/geom/flat.py +++ b/src/sisl/geom/flat.py @@ -8,14 +8,12 @@ from ._common import geometry_define_nsc -__all__ = ['honeycomb', 'graphene', 'honeycomb_flake', 'graphene_flake'] +__all__ = ["honeycomb", "graphene", "honeycomb_flake", "graphene_flake"] @set_module("sisl.geom") -def honeycomb(bond: float, - atoms, - orthogonal: bool=False): - """ Honeycomb lattice with 2 or 4 atoms per unit-cell, latter orthogonal cell +def honeycomb(bond: float, atoms, orthogonal: bool = False): + """Honeycomb lattice with 2 or 4 atoms per unit-cell, latter orthogonal cell This enables creating BN lattices with ease, or graphene lattices. @@ -33,32 +31,42 @@ def honeycomb(bond: float, graphene: the equivalent of this, but with default of Carbon atoms bilayer: create bilayer honeycomb lattices """ - sq3h = 3.**.5 * 0.5 + sq3h = 3.0**0.5 * 0.5 if orthogonal: - lattice = Lattice(np.array([[3., 0., 0.], - [0., 2 * sq3h, 0.], - [0., 0., 10.]], np.float64) * bond) - g = Geometry(np.array([[0., 0., 0.], - [0.5, sq3h, 0.], - [1.5, sq3h, 0.], - [2., 0., 0.]], np.float64) * bond, - atoms, lattice=lattice) + lattice = Lattice( + np.array( + [[3.0, 0.0, 0.0], [0.0, 2 * sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond + ) + g = Geometry( + np.array( + [[0.0, 0.0, 0.0], [0.5, sq3h, 0.0], [1.5, sq3h, 0.0], [2.0, 0.0, 0.0]], + np.float64, + ) + * bond, + atoms, + lattice=lattice, + ) else: - lattice = Lattice(np.array([[1.5, -sq3h, 0.], - [1.5, sq3h, 0.], - [0., 0., 10.]], np.float64) * bond) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms, lattice=lattice) + lattice = Lattice( + np.array( + [[1.5, -sq3h, 0.0], [1.5, sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond + ) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms, + lattice=lattice, + ) geometry_define_nsc(g, [True, True, False]) return g @set_module("sisl.geom") -def graphene(bond: float=1.42, - atoms=None, - orthogonal: bool=False): - """ Graphene lattice with 2 or 4 atoms per unit-cell, latter orthogonal cell +def graphene(bond: float = 1.42, atoms=None, orthogonal: bool = False): + """Graphene lattice with 2 or 4 atoms per unit-cell, latter orthogonal cell Parameters ---------- @@ -81,10 +89,7 @@ def graphene(bond: float=1.42, @set_module("sisl.geom") -def honeycomb_flake(shells: int, - bond: float, - atoms, - vacuum: float = 20.) -> Geometry: +def honeycomb_flake(shells: int, bond: float, atoms, vacuum: float = 20.0) -> Geometry: """Hexagonal flake of a honeycomb lattice, with zig-zag edges. Parameters @@ -104,7 +109,6 @@ def honeycomb_flake(shells: int, # hexagonal flake. The rest of the portions are obtained by rotating # this one by 60, 120, 180, 240 and 300 degrees. def _minimal_op(shells): - # The function is based on the horizontal lines of the hexagon, # which are made of a pair of atoms. # For each shell, we first need to complete the incomplete horizontal @@ -112,8 +116,8 @@ def _minimal_op(shells): # the next horizontal lines. # Displacement from the end of one horizontal pair to the beggining of the next - branch_displ_x = bond * 0.5 # cos(60) = 0.5 - branch_displ_y = bond * 3 ** 0.5 / 2 # sin(60) = sqrt(3)/2 + branch_displ_x = bond * 0.5 # cos(60) = 0.5 + branch_displ_y = bond * 3**0.5 / 2 # sin(60) = sqrt(3)/2 # Iterate over shells. We also keep track of the atom types, in case # we have two different atoms in the honeycomb lattice. @@ -122,44 +126,48 @@ def _minimal_op(shells): for shell in range(shells): n_new_branches = 2 + shell prev_edge = branch_displ_y * (shell) - + sat = np.zeros((shell + 1, 3)) sat[:, 0] = op[-1, 0] + bond sat[:, 1] = np.linspace(-prev_edge, prev_edge, shell + 1) - + edge = branch_displ_y * (shell + 1) branches = np.zeros((n_new_branches, 3)) branches[:, 0] = sat[0, 0] + branch_displ_x - branches[:, 1] = np.linspace(-edge, edge, n_new_branches ) + branches[:, 1] = np.linspace(-edge, edge, n_new_branches) op = np.concatenate([op, sat, branches]) - types = np.concatenate([types, np.full(len(sat), 1), np.full(len(branches), 0)]) + types = np.concatenate( + [types, np.full(len(sat), 1), np.full(len(branches), 0)] + ) return op, types - + # Get the coordinates of 1/6 of the hexagon for the requested number of shells. op, types = _minimal_op(shells) single_atom_type = isinstance(atoms, (str, Atom)) or len(atoms) == 1 # Create a geometry from the coordinates. - ats = atoms if single_atom_type else np.asarray(atoms)[types] + ats = atoms if single_atom_type else np.asarray(atoms)[types] geom = Geometry(op, atoms=ats) # The second portion of the hexagon is obtained by rotating the first one by 60 degrees. # However, if there are two different atoms in the honeycomb lattice, we need to reverse the types. - next_triangle = geom if single_atom_type else Geometry(op, atoms=np.asarray(atoms)[types - 1]) - geom += next_triangle.rotate(60, [0,0,1]) + next_triangle = ( + geom if single_atom_type else Geometry(op, atoms=np.asarray(atoms)[types - 1]) + ) + geom += next_triangle.rotate(60, [0, 0, 1]) # Then just rotate the two triangles by 120 and 240 degrees to get the full hexagon. - geom += geom.rotate(120, [0,0,1]) + geom.rotate(240, [0,0,1]) + geom += geom.rotate(120, [0, 0, 1]) + geom.rotate(240, [0, 0, 1]) # Set the cell according to the requested vacuum - max_x = np.max(geom.xyz[:,0]) - geom.cell[0,0] = max_x * 2 + vacuum - geom.cell[1,1] = max_x * 2 + vacuum - geom.cell[2,2] = 20. + max_x = np.max(geom.xyz[:, 0]) + geom.cell[0, 0] = max_x * 2 + vacuum + geom.cell[1, 1] = max_x * 2 + vacuum + geom.cell[2, 2] = 20.0 # Center the flake geom = geom.translate(geom.center(what="cell")) @@ -171,10 +179,9 @@ def _minimal_op(shells): @set_module("sisl.geom") -def graphene_flake(shells: int, - bond: float = 1.42, - atoms=None, - vacuum: float = 20.) -> Geometry: +def graphene_flake( + shells: int, bond: float = 1.42, atoms=None, vacuum: float = 20.0 +) -> Geometry: """Hexagonal flake of graphene, with zig-zag edges. Parameters diff --git a/src/sisl/geom/nanoribbon.py b/src/sisl/geom/nanoribbon.py index ca9fb06a08..1543a386e9 100644 --- a/src/sisl/geom/nanoribbon.py +++ b/src/sisl/geom/nanoribbon.py @@ -13,17 +13,18 @@ from ._composite import CompositeGeometrySection, composite_geometry __all__ = [ - 'nanoribbon', 'graphene_nanoribbon', 'agnr', 'zgnr', - 'heteroribbon', 'graphene_heteroribbon', + "nanoribbon", + "graphene_nanoribbon", + "agnr", + "zgnr", + "heteroribbon", + "graphene_heteroribbon", ] @set_module("sisl.geom") -def nanoribbon(width: int, - bond: float, - atoms, - kind: str='armchair'): - r""" Construction of a nanoribbon unit cell of type armchair or zigzag. +def nanoribbon(width: int, bond: float, atoms, kind: str = "armchair"): + r"""Construction of a nanoribbon unit cell of type armchair or zigzag. The geometry is oriented along the :math:`x` axis. @@ -75,7 +76,9 @@ def nanoribbon(width: int, # Invert y-coordinates ribbon.xyz[:, 1] *= -1 # Set lattice vectors strictly orthogonal - ribbon.cell[:, :] = np.diag([ribbon.cell[1, 0], -ribbon.cell[0, 1], ribbon.cell[2, 2]]) + ribbon.cell[:, :] = np.diag( + [ribbon.cell[1, 0], -ribbon.cell[0, 1], ribbon.cell[2, 2]] + ) # Sort along x, then y ribbon = ribbon.sort(axis=(0, 1)) @@ -83,7 +86,7 @@ def nanoribbon(width: int, raise ValueError(f"nanoribbon: kind must be armchair or zigzag ({kind})") # Separate ribbons along y-axis - ribbon.cell[1, 1] += 20. + ribbon.cell[1, 1] += 20.0 # Move inside unit cell xyz = ribbon.xyz.min(axis=0) * [1, 1, 0] @@ -94,11 +97,10 @@ def nanoribbon(width: int, @set_module("sisl.geom") -def graphene_nanoribbon(width: int, - bond: float=1.42, - atoms=None, - kind: str='armchair'): - r""" Construction of a graphene nanoribbon +def graphene_nanoribbon( + width: int, bond: float = 1.42, atoms=None, kind: str = "armchair" +): + r"""Construction of a graphene nanoribbon Parameters ---------- @@ -125,10 +127,8 @@ def graphene_nanoribbon(width: int, @set_module("sisl.geom") -def agnr(width: int, - bond: float=1.42, - atoms=None): - r""" Construction of an armchair graphene nanoribbon +def agnr(width: int, bond: float = 1.42, atoms=None): + r"""Construction of an armchair graphene nanoribbon Parameters ---------- @@ -147,14 +147,12 @@ def agnr(width: int, graphene_nanoribbon : graphene nanoribbon zgnr : zigzag graphene nanoribbon """ - return graphene_nanoribbon(width, bond, atoms, kind='armchair') + return graphene_nanoribbon(width, bond, atoms, kind="armchair") @set_module("sisl.geom") -def zgnr(width: int, - bond: float=1.42, - atoms=None): - r""" Construction of a zigzag graphene nanoribbon +def zgnr(width: int, bond: float = 1.42, atoms=None): + r"""Construction of a zigzag graphene nanoribbon Parameters ---------- @@ -173,7 +171,7 @@ def zgnr(width: int, graphene_nanoribbon : graphene nanoribbon agnr : armchair graphene nanoribbon """ - return graphene_nanoribbon(width, bond, atoms, kind='zigzag') + return graphene_nanoribbon(width, bond, atoms, kind="zigzag") @set_module("sisl.geom") @@ -211,7 +209,7 @@ class _heteroribbon_section(CompositeGeometrySection): If ``False``, sections are just shifted (`shift`) number of atoms. - If ``True``, shifts are quantized in the sense that shifts that produce + If ``True``, shifts are quantized in the sense that shifts that produce lone atoms (< 2 neighbours) are ignored. Then: - ``shift = 0`` means aligned. - ``shift = -1`` means first possible downwards shift (if available). @@ -219,13 +217,14 @@ class _heteroribbon_section(CompositeGeometrySection): If this is set to `True`, `on_lone_atom` is overwritten to `"raise"`. on_lone_atom: {'ignore', 'warn', 'raise'} What to do when a junction between sections produces lone atoms (< 2 neighbours) - + Messages contain hopefully useful explanations to understand what to do to fix it. invert_first: bool, optional - Whether, if this is the first section, it should be inverted with respect + Whether, if this is the first section, it should be inverted with respect to the one provided by `sisl.geom.nanoribbon`. """ + W: int L: int = 1 shift: int = 0 @@ -247,7 +246,7 @@ def __post_init__(self): if self.shift_quantum: self.on_lone_atom = "raise" - + self._open_borders = [False, False] def _shift_unit_cell(self, geometry): @@ -256,7 +255,7 @@ def _shift_unit_cell(self, geometry): It does so by shifting half unit cell. This must be done before any tiling of the geometry. """ - move = np.array([0., 0., 0.]) + move = np.array([0.0, 0.0, 0.0]) move[self.long_ax] = -geometry.xyz[self.W, self.long_ax] geometry = geometry.move(move) geometry.xyz = (geometry.fxyz % 1).dot(geometry.cell) @@ -271,11 +270,11 @@ def _align_offset(self, prev, new_xyz): align = self.align.lower() if prev is None: return align, 0 - + W = self.W W_diff = W - prev.W if align in ("a", "auto"): - if (W % 2 == 1 and W_diff % 2 == 0): + if W % 2 == 1 and W_diff % 2 == 0: # Both ribbons are odd, so we align on the center align = "c" elif prev.W % 2 == 0: @@ -289,15 +288,30 @@ def _align_offset(self, prev, new_xyz): if align in ("c", "center"): if W_diff % 2 == 1: - self._junction_error(prev, "Different parity sections can not be aligned by their centers", "raise") - return align, prev.xyz[:, self.trans_ax].mean() - new_xyz[:, self.trans_ax].mean() + self._junction_error( + prev, + "Different parity sections can not be aligned by their centers", + "raise", + ) + return ( + align, + prev.xyz[:, self.trans_ax].mean() - new_xyz[:, self.trans_ax].mean(), + ) elif align in ("t", "top"): - return align, prev.xyz[:, self.trans_ax].max() - new_xyz[:, self.trans_ax].max() + return ( + align, + prev.xyz[:, self.trans_ax].max() - new_xyz[:, self.trans_ax].max(), + ) elif align in ("b", "bottom"): - return align, prev.xyz[:, self.trans_ax].min() - new_xyz[:, self.trans_ax].min() + return ( + align, + prev.xyz[:, self.trans_ax].min() - new_xyz[:, self.trans_ax].min(), + ) else: - raise ValueError(f"Invalid value for 'align': {align}. Must be one of" - " {'c', 'center', 't', 'top', 'b', 'bottom', 'a', 'auto'}") + raise ValueError( + f"Invalid value for 'align': {align}. Must be one of" + " {'c', 'center', 't', 'top', 'b', 'bottom', 'a', 'auto'}" + ) def _offset_from_center(self, align, prev): align = align.lower()[0] @@ -318,25 +332,30 @@ def _offset_from_center(self, align, prev): def build_section(self, prev): new_section = nanoribbon( - bond=self.bond, atoms=self.atoms, - width=self.W, kind=self.kind + bond=self.bond, atoms=self.atoms, width=self.W, kind=self.kind ) align, offset = self._align_offset(prev, new_section) if prev is not None: if not isinstance(prev, _heteroribbon_section): - self._junction_error(prev, f"{self.__class__.__name__} can not be appended to {type(prev).__name__}", "raise") + self._junction_error( + prev, + f"{self.__class__.__name__} can not be appended to {type(prev).__name__}", + "raise", + ) if self.kind != prev.kind: self._junction_error(prev, f"Ribbons must be of same type.", "raise") if self.bond != prev.bond: - self._junction_error(prev, f"Ribbons must have same bond length.", "raise") + self._junction_error( + prev, f"Ribbons must have same bond length.", "raise" + ) shift = self._parse_shift(self.shift, prev, align) - + # Get the distance of an atom shift. (sin(60) = 3**.5 / 2) - atom_shift = self.bond * 3**.5 / 2 - + atom_shift = self.bond * 3**0.5 / 2 + # if (last_W % 2 == 1 and W < last_W) and last_open: # _junction_error(i, "DANGLING BONDS: Previous odd section, which has an open end," # " is wider than the incoming one. A wider odd section must always" @@ -363,7 +382,7 @@ def build_section(self, prev): if aligned_match == (shift % 2 == 1): new_section = self._shift_unit_cell(new_section) self._open_borders[0] = not self._open_borders[0] - + # Apply the offset that we have calculated. move = np.zeros(3) move[self.trans_ax] = offset + shift * atom_shift @@ -383,11 +402,17 @@ def build_section(self, prev): # Cut the last string of atoms. if cut_last: new_section.cell[0, 0] *= self.L / (self.L + 1) - new_section = new_section.remove({"xy"[self.long_ax]: - (new_section.cell[self.long_ax, self.long_ax] - 0.01, None)}) + new_section = new_section.remove( + { + "xy"[self.long_ax]: ( + new_section.cell[self.long_ax, self.long_ax] - 0.01, + None, + ) + } + ) self._open_borders[1] = self._open_borders[0] != cut_last - + self.xyz = new_section.xyz return new_section @@ -396,14 +421,16 @@ def add_section(self, geom, new_section): new_min = new_section[:, self.trans_ax].min() new_max = new_section[:, self.trans_ax].max() if new_min < 0: - cell_offset = - new_min + 14 + cell_offset = -new_min + 14 geom = geom.add_vacuum(cell_offset, self.trans_ax) move = np.zeros(3) move[self.trans_ax] = cell_offset geom = geom.move(move) new_section = new_section.move(move) if new_max > geom.cell[1, 1]: - geom = geom.add_vacuum(new_max - geom.cell[self.trans_ax, self.trans_ax] + 14, self.trans_ax) + geom = geom.add_vacuum( + new_max - geom.cell[self.trans_ax, self.trans_ax] + 14, self.trans_ax + ) self.xyz = new_section.xyz # Finally, we can safely append the geometry. @@ -419,12 +446,15 @@ def _parse_shift(self, shift, prev, align): # a smaller ribbon, since no matter what the shift is there will # always be dangling bonds. if (prev.W % 2 == 1 and W < prev.W) and prev._open_borders[1]: - self._junction_error(prev, "LONE ATOMS: Previous odd section, which has an open end," + self._junction_error( + prev, + "LONE ATOMS: Previous odd section, which has an open end," " is wider than the incoming one. A wider odd section must always" " have a closed end. You can solve this by making the previous section" - " one unit smaller or larger (L = L +- 1).", self.on_lone_atom + " one unit smaller or larger (L = L +- 1).", + self.on_lone_atom, ) - + # Get the difference in width between the previous and this ribbon section W_diff = W - prev.W # And also the mod(4) because there are certain differences if the width differs @@ -448,16 +478,19 @@ def _parse_shift(self, shift, prev, align): # the previous section. shift_lims = { "closed": prev.W // 2 + W // 2 - 2, - "open": prev.W // 2 - W // 2 - 1 + "open": prev.W // 2 - W // 2 - 1, } - shift_pars = { - lim % 2: lim for k, lim in shift_lims.items() - } + shift_pars = {lim % 2: lim for k, lim in shift_lims.items()} # Build an array with all the valid shifts. - valid_shifts = np.sort([*np.arange(0, shift_pars[0] + 1, 2), *np.arange(1, shift_pars[1] + 1, 2)]) - valid_shifts = np.array([*(np.flip(-valid_shifts)[:-1]), *valid_shifts]) + valid_shifts = np.sort( + [ + *np.arange(0, shift_pars[0] + 1, 2), + *np.arange(1, shift_pars[1] + 1, 2), + ] + ) + valid_shifts = np.array([*(np.flip(-valid_shifts)[:-1]), *valid_shifts]) # Update the valid shift limits if the sections are aligned on any of the edges. shift_offset = self._offset_from_center(align, prev) @@ -467,7 +500,12 @@ def _parse_shift(self, shift, prev, align): else: # At this point, we already know that the incoming section is wider and # therefore it MUST have a closed start, otherwise there will be dangling bonds. - if diff_mod == 2 and prev._open_borders[1] or diff_mod == 0 and not prev._open_borders[1]: + if ( + diff_mod == 2 + and prev._open_borders[1] + or diff_mod == 0 + and not prev._open_borders[1] + ): # In these cases, centers must match or differ by an even number of atoms. # And this is the limit for the shift from the center. shift_lim = ((W_diff // 2) // 2) * 2 @@ -485,7 +523,10 @@ def _parse_shift(self, shift, prev, align): shift_offset = self._offset_from_center(align, prev) # Apply the offsets and calculate the maximum and minimum shifts. - min_shift, max_shift = -shift_lim - shift_offset, shift_lim - shift_offset + min_shift, max_shift = ( + -shift_lim - shift_offset, + shift_lim - shift_offset, + ) valid_shifts = np.arange(min_shift, max_shift + 1, 2) else: @@ -507,7 +548,7 @@ def _parse_shift(self, shift, prev, align): max_shift = prev.W - 1 else: special_shifts = [0] - min_shift = - W + 1 + min_shift = -W + 1 max_shift = -1 elif W % 2 == 1: @@ -516,10 +557,10 @@ def _parse_shift(self, shift, prev, align): if prev._open_borders[1]: special_shifts = [prev.W - W] min_shift = prev.W - W - max_shift = prev.W - W + 1 + ((W - 2) // 2)*2 + max_shift = prev.W - W + 1 + ((W - 2) // 2) * 2 else: special_shifts = [0] - min_shift = -1 - ((W - 2) // 2)*2 + min_shift = -1 - ((W - 2) // 2) * 2 max_shift = -1 else: if prev._open_borders[1]: @@ -527,14 +568,14 @@ def _parse_shift(self, shift, prev, align): max_shift = prev.W - 2 else: max_shift = -1 - min_shift = - (W - 2) + min_shift = -(W - 2) else: # Last section was odd, incoming section is even. if prev._open_borders[1]: special_shifts = [0, prev.W - W] min_shift = None else: - min_shift = [1, - (W - 2)] + min_shift = [1, -(W - 2)] max_shift = [prev.W - 2, prev.W - W - 1] # We have gone over all possible situations, now just build the @@ -543,7 +584,7 @@ def _parse_shift(self, shift, prev, align): if isinstance(min_shift, int): valid_shifts.extend(np.arange(min_shift, max_shift + 1, 2)) elif isinstance(min_shift, list): - for (m, mx) in zip(min_shift, max_shift): + for m, mx in zip(min_shift, max_shift): valid_shifts.extend(np.arange(m, mx + 1, 2)) # Apply offset on shifts based on the actual alignment requested @@ -552,7 +593,7 @@ def _parse_shift(self, shift, prev, align): if align[0] == "t": shift_offset = W_diff elif align[0] == "c": - shift_offset = - self._offset_from_center("b", prev) + shift_offset = -self._offset_from_center("b", prev) valid_shifts = np.array(valid_shifts) + shift_offset # Finally, check if the provided shift value is valid or not. @@ -567,32 +608,37 @@ def _parse_shift(self, shift, prev, align): # What flip does is prioritize upwards shifts. # That is, if both "-1" and "1" shifts are valid, # "1" will be picked as the reference. - aligned_shift = n_valid_shifts - 1 - np.argmin(np.abs(np.flip(valid_shifts))) + aligned_shift = ( + n_valid_shifts - 1 - np.argmin(np.abs(np.flip(valid_shifts))) + ) # Calculate which index we really need to retrieve corrected_shift = aligned_shift + shift if corrected_shift < 0 or corrected_shift >= n_valid_shifts: - self._junction_error(prev, f"LONE ATOMS: Shift must be between {-aligned_shift}" + self._junction_error( + prev, + f"LONE ATOMS: Shift must be between {-aligned_shift}" f" and {n_valid_shifts - aligned_shift - 1}, but {shift} was provided.", - self.on_lone_atom + self.on_lone_atom, ) - + # And finally get the shift value shift = valid_shifts[corrected_shift] else: if shift not in valid_shifts: - self._junction_error(prev, f"LONE ATOMS: Shift must be one of {valid_shifts}" - f" but {shift} was provided.", self.on_lone_atom + self._junction_error( + prev, + f"LONE ATOMS: Shift must be one of {valid_shifts}" + f" but {shift} was provided.", + self.on_lone_atom, ) return shift - + @set_module("sisl.geom") -def heteroribbon(sections, - section_cls=_heteroribbon_section, - **kwargs): +def heteroribbon(sections, section_cls=_heteroribbon_section, **kwargs): """Build a nanoribbon consisting of several nanoribbons of different widths. This function uses `composite_geometry`, but defaulting to the usage @@ -631,15 +677,18 @@ def heteroribbon(sections, """ return composite_geometry(sections, section_cls=section_cls, **kwargs) + heteroribbon.section = _heteroribbon_section @set_module("sisl.geom") -def graphene_heteroribbon(sections, - section_cls=_heteroribbon_section, - bond: float=1.42, - atoms=None, - **kwargs): +def graphene_heteroribbon( + sections, + section_cls=_heteroribbon_section, + bond: float = 1.42, + atoms=None, + **kwargs, +): """Build a graphene nanoribbon consisting of several nanoribbons of different widths Please see `heteroribbon` for arguments, the only difference is that the `bond` and `atoms` @@ -651,6 +700,9 @@ def graphene_heteroribbon(sections, """ if atoms is None: atoms = Atom(Z=6, R=bond * 1.01) - return composite_geometry(sections, section_cls=section_cls, bond=bond, atoms=atoms, **kwargs) + return composite_geometry( + sections, section_cls=section_cls, bond=bond, atoms=atoms, **kwargs + ) + graphene_heteroribbon.section = _heteroribbon_section diff --git a/src/sisl/geom/nanotube.py b/src/sisl/geom/nanotube.py index 3d501dbc4f..3e2bc11c2b 100644 --- a/src/sisl/geom/nanotube.py +++ b/src/sisl/geom/nanotube.py @@ -10,14 +10,12 @@ from ._common import geometry_define_nsc -__all__ = ['nanotube'] +__all__ = ["nanotube"] @set_module("sisl.geom") -def nanotube(bond: float, - atoms=None, - chirality: Tuple[int, int]=(1, 1)): - """ Nanotube with user-defined chirality. +def nanotube(bond: float, atoms=None, chirality: Tuple[int, int] = (1, 1)): + """Nanotube with user-defined chirality. This routine is implemented as in `ASE`_ with some cosmetic changes. @@ -41,10 +39,10 @@ def nanotube(bond: float, else: sign = 1 - sq3 = 3.0 ** .5 + sq3 = 3.0**0.5 a = sq3 * bond l2 = n * n + m * m + n * m - l = l2 ** .5 + l = l2**0.5 def gcd(a, b): while a != 0: @@ -81,15 +79,15 @@ def gcd(a, b): nnq.append(j) if ichk == 0: - raise RuntimeError('not found p, q strange!!') + raise RuntimeError("not found p, q strange!!") if ichk >= 2: - raise RuntimeError('more than 1 pair p, q strange!!') + raise RuntimeError("more than 1 pair p, q strange!!") nnnp = nnp[0] nnnq = nnq[0] lp = nnnp * nnnp + nnnq * nnnq + nnnp * nnnq - r = a * lp ** .5 + r = a * lp**0.5 c = a * l t = sq3 * c / ndr @@ -105,7 +103,7 @@ def gcd(a, b): h1 = abs(t) / abs(np.sin(q3)) h2 = bond * np.sin((np.pi / 6.0) - q1) - xyz = np.empty([nn*2, 3], np.float64) + xyz = np.empty([nn * 2, 3], np.float64) for i in range(nn): ix = i * 2 diff --git a/src/sisl/geom/special.py b/src/sisl/geom/special.py index a0a0424054..8e4e78cf11 100644 --- a/src/sisl/geom/special.py +++ b/src/sisl/geom/special.py @@ -8,13 +8,12 @@ from ._common import geometry_define_nsc -__all__ = ['diamond'] +__all__ = ["diamond"] @set_module("sisl.geom") -def diamond(alat: float=3.57, - atoms=None): - """ Diamond lattice with 2 atoms in the unitcell +def diamond(alat: float = 3.57, atoms=None): + """Diamond lattice with 2 atoms in the unitcell Parameters ---------- @@ -23,14 +22,15 @@ def diamond(alat: float=3.57, atoms : Atom, optional atom in the lattice, may be one or two atoms. Default is Carbon """ - dist = alat * 3. ** .5 / 4 + dist = alat * 3.0**0.5 / 4 if atoms is None: atoms = Atom(Z=6, R=dist * 1.01) - lattice = Lattice(np.array([[0, 1, 1], - [1, 0, 1], - [1, 1, 0]], np.float64) * alat / 2) - dia = Geometry(np.array([[0, 0, 0], [1, 1, 1]], np.float64) * alat / 4, - atoms, lattice=lattice) + lattice = Lattice( + np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], np.float64) * alat / 2 + ) + dia = Geometry( + np.array([[0, 0, 0], [1, 1, 1]], np.float64) * alat / 4, atoms, lattice=lattice + ) geometry_define_nsc(dia) return dia diff --git a/src/sisl/geom/surfaces.py b/src/sisl/geom/surfaces.py index fbebcc7474..81039f211f 100644 --- a/src/sisl/geom/surfaces.py +++ b/src/sisl/geom/surfaces.py @@ -13,7 +13,7 @@ from ._common import geometry2uc, geometry_define_nsc -__all__ = ['fcc_slab', 'bcc_slab', 'rocksalt_slab'] +__all__ = ["fcc_slab", "bcc_slab", "rocksalt_slab"] def _layer2int(layer, periodicity): @@ -57,7 +57,7 @@ def _calc_info(start, end, layers, periodicity): layers = layers[:nlayers] elif start is None: # end is not none - layers = layers[end+1:] + layers[:end+1] + layers = layers[end + 1 :] + layers[: end + 1] layers = layers[-nlayers:] elif end is None: # start is not none @@ -71,7 +71,9 @@ def _calc_info(start, end, layers, periodicity): # + 2 to allow rotating (stacking * (nlayers // periodicity + 2)).index(layers) except ValueError: - raise NotImplementedError(f"Stacking faults are not implemented, requested {layers} with stacking {stacking}") + raise NotImplementedError( + f"Stacking faults are not implemented, requested {layers} with stacking {stacking}" + ) if start is None and end is None: # easy case, we just calculate one of them @@ -79,12 +81,16 @@ def _calc_info(start, end, layers, periodicity): elif start is not None: if _layer2int(layers[0], periodicity) != start: - raise ValueError(f"Passing both 'layers' and 'start' requires them to be conforming; found layers={layers} " - f"and start={'ABCDEF'[start]}") + raise ValueError( + f"Passing both 'layers' and 'start' requires them to be conforming; found layers={layers} " + f"and start={'ABCDEF'[start]}" + ) elif end is not None: if _layer2int(layers[-1], periodicity) != end: - raise ValueError(f"Passing both 'layers' and 'end' requires them to be conforming; found layers={layers} " - f"and end={'ABCDEF'[end]}") + raise ValueError( + f"Passing both 'layers' and 'end' requires them to be conforming; found layers={layers} " + f"and end={'ABCDEF'[end]}" + ) # a sanity check for the algorithm, should always hold! if start is not None: @@ -124,13 +130,13 @@ def _convert_miller(miller): def _slab_with_vacuum(func, *args, **kwargs): - """Function to wrap `func` with vacuum in between """ + """Function to wrap `func` with vacuum in between""" layers = kwargs.pop("layers") if layers is None or isinstance(layers, Integral): return None def is_vacuum(layer): - """ A vacuum is defined by one of these variables: + """A vacuum is defined by one of these variables: - None - ' ' @@ -139,20 +145,22 @@ def is_vacuum(layer): if layer is None: return True if isinstance(layer, str): - return layer == ' ' + return layer == " " if isinstance(layer, Integral): return layer == 0 return False # we are dealing either with a list of ints or str if isinstance(layers, str): - nvacuums = layers.count(' ') + nvacuums = layers.count(" ") if nvacuums == 0: return None - if layers.count(' ') > 0: - raise ValueError("Denoting several vacuum layers next to each other is not supported. " - "Please pass 'vacuum' as an array instead.") + if layers.count(" ") > 0: + raise ValueError( + "Denoting several vacuum layers next to each other is not supported. " + "Please pass 'vacuum' as an array instead." + ) # determine number of slabs nslabs = len(layers.strip().split()) @@ -163,9 +171,12 @@ def are_layers(a, b): a_layer = not is_vacuum(a) b_layer = not is_vacuum(b) return a_layer and b_layer + # convert list correctly - layers = [[p, None] if are_layers(p, n) else [p] - for p, n in zip(layers[:-1], layers[1:])] + [[layers[-1]]] + layers = [ + [p, None] if are_layers(p, n) else [p] + for p, n in zip(layers[:-1], layers[1:]) + ] + [[layers[-1]]] layers = [l for ls in layers for l in ls] nvacuums = sum([1 if is_vacuum(l) else 0 for l in layers]) nslabs = sum([0 if is_vacuum(l) else 1 for l in layers]) @@ -179,23 +190,26 @@ def ensure_length(var, nslabs, name): return [var] * nslabs if len(var) > nslabs: - raise ValueError(f"Specification of {name} has too many elements compared to the " - f"number of slabs {nslabs}, please reduce length from {len(var)}.") + raise ValueError( + f"Specification of {name} has too many elements compared to the " + f"number of slabs {nslabs}, please reduce length from {len(var)}." + ) # it must be an array of some sorts out = [None] * nslabs - out[:len(var)] = var[:] + out[: len(var)] = var[:] return out + start = ensure_length(kwargs.pop("start"), nslabs, "start") end = ensure_length(kwargs.pop("end"), nslabs, "end") vacuum = np.asarray(kwargs.pop("vacuum")) - vacuums = np.full(nvacuums, 0.) + vacuums = np.full(nvacuums, 0.0) if vacuum.ndim == 0: vacuums[:] = vacuum else: - vacuums[:len(vacuum)] = vacuum - vacuums[len(vacuum):] = vacuum[-1] + vacuums[: len(vacuum)] = vacuum + vacuums[len(vacuum) :] = vacuum[-1] vacuums = vacuums.tolist() # We are now sure that there is a vacuum! @@ -206,7 +220,7 @@ def iter_func(key, layer): # layer is an iterator, convert to list layer = list(layer) if isinstance(layer[0], str): - layer = ''.join(layer) + layer = "".join(layer) elif len(layer) > 1: raise ValueError(f"Grouper returned long list {layer}") else: @@ -218,9 +232,11 @@ def iter_func(key, layer): # group stuff layers = [ iter_func(key, group) - for key, group in groupby(layers, - # group by vacuum positions and not vacuum positions - lambda l: 0 if is_vacuum(l) else 1) + for key, group in groupby( + layers, + # group by vacuum positions and not vacuum positions + lambda l: 0 if is_vacuum(l) else 1, + ) ] # Now we need to loop and create the things @@ -228,12 +244,15 @@ def iter_func(key, layer): ivacuum = 0 islab = 0 if layers[0] is None: - layers.pop(0) # vacuum specification - out = func(*args, - layers=layers.pop(0), - start=start.pop(0), - end=end.pop(0), - vacuum=None, **kwargs) + layers.pop(0) # vacuum specification + out = func( + *args, + layers=layers.pop(0), + start=start.pop(0), + end=end.pop(0), + vacuum=None, + **kwargs, + ) # add vacuum vacuum = Lattice([0, 0, vacuums.pop(0)]) out = out.add(vacuum, offset=(0, 0, vacuum.cell[2, 2])) @@ -241,11 +260,14 @@ def iter_func(key, layer): islab += 1 else: - out = func(*args, - layers=layers.pop(0), - start=start.pop(0), - end=end.pop(0), - vacuum=None, **kwargs) + out = func( + *args, + layers=layers.pop(0), + start=start.pop(0), + end=end.pop(0), + vacuum=None, + **kwargs, + ) islab += 1 while len(layers) > 0: @@ -257,11 +279,14 @@ def iter_func(key, layer): ivacuum += 1 out = out.add(vacuum) else: - geom = func(*args, - layers=layer, - start=start.pop(0), - end=end.pop(0), - vacuum=None, **kwargs) + geom = func( + *args, + layers=layer, + start=start.pop(0), + end=end.pop(0), + vacuum=None, + **kwargs, + ) out = out.append(geom, 2) islab += 1 @@ -275,16 +300,18 @@ def iter_func(key, layer): @set_module("sisl.geom") -def fcc_slab(alat: float, - atoms, - miller: Union[int, str, Tuple[int, int, int]], - layers=None, - vacuum: Union[float, Sequence[float]]=20., - *, - orthogonal: bool=False, - start=None, - end=None): - r""" Surface slab forming a face-centered cubic (FCC) crystal +def fcc_slab( + alat: float, + atoms, + miller: Union[int, str, Tuple[int, int, int]], + layers=None, + vacuum: Union[float, Sequence[float]] = 20.0, + *, + orthogonal: bool = False, + start=None, + end=None, +): + r"""Surface slab forming a face-centered cubic (FCC) crystal The slab layers are stacked along the :math:`z`-axis. The default stacking is the first layer as an A-layer, defined as the plane containing an atom at :math:`(x,y)=(0,0)`. @@ -394,20 +421,26 @@ def fcc_slab(alat: float, bcc_slab : Slab in BCC structure rocksalt_slab : Slab in rocksalt/halite structure """ - geom = _slab_with_vacuum(fcc_slab, alat, atoms, miller, - vacuum=vacuum, orthogonal=orthogonal, - layers=layers, - start=start, end=end) + geom = _slab_with_vacuum( + fcc_slab, + alat, + atoms, + miller, + vacuum=vacuum, + orthogonal=orthogonal, + layers=layers, + start=start, + end=end, + ) if geom is not None: return geom miller = _convert_miller(miller) if miller == (1, 0, 0): - info = _calc_info(start, end, layers, 2) - lattice = Lattice(np.array([0.5 ** 0.5, 0.5 ** 0.5, 0.5]) * alat) + lattice = Lattice(np.array([0.5**0.5, 0.5**0.5, 0.5]) * alat) g = Geometry([0, 0, 0], atoms=atoms, lattice=lattice) g = g.tile(info.nlayers, 2) @@ -416,10 +449,9 @@ def fcc_slab(alat: float, g.xyz[B::2] += (lattice.cell[0] + lattice.cell[1]) / 2 elif miller == (1, 1, 0): - info = _calc_info(start, end, layers, 2) - lattice = Lattice(np.array([1., 0.5, 0.125]) ** 0.5 * alat) + lattice = Lattice(np.array([1.0, 0.5, 0.125]) ** 0.5 * alat) g = Geometry([0, 0, 0], atoms=atoms, lattice=lattice) g = g.tile(info.nlayers, 2) @@ -428,14 +460,15 @@ def fcc_slab(alat: float, g.xyz[B::2] += (lattice.cell[0] + lattice.cell[1]) / 2 elif miller == (1, 1, 1): - info = _calc_info(start, end, layers, 3) if orthogonal: lattice = Lattice(np.array([0.5, 4 * 0.375, 1 / 3]) ** 0.5 * alat) - g = Geometry(np.array([[0, 0, 0], - [0.125, 0.375, 0]]) ** 0.5 * alat, - atoms=atoms, lattice=lattice) + g = Geometry( + np.array([[0, 0, 0], [0.125, 0.375, 0]]) ** 0.5 * alat, + atoms=atoms, + lattice=lattice, + ) g = g.tile(info.nlayers, 2) # slide ABC layers relative to each other @@ -443,14 +476,14 @@ def fcc_slab(alat: float, C = 2 * (info.offset + 2) % 6 vec = (3 * lattice.cell[0] + lattice.cell[1]) / 6 g.xyz[B::6] += vec - g.xyz[B+1::6] += vec + g.xyz[B + 1 :: 6] += vec g.xyz[C::6] += 2 * vec - g.xyz[C+1::6] += 2 * vec + g.xyz[C + 1 :: 6] += 2 * vec else: - lattice = Lattice(np.array([[0.5, 0, 0], - [0.125, 0.375, 0], - [0, 0, 1 / 3]]) ** 0.5 * alat) + lattice = Lattice( + np.array([[0.5, 0, 0], [0.125, 0.375, 0], [0, 0, 1 / 3]]) ** 0.5 * alat + ) g = Geometry([0, 0, 0], atoms=atoms, lattice=lattice) g = g.tile(info.nlayers, 2) @@ -469,16 +502,18 @@ def fcc_slab(alat: float, @set_module("sisl.geom") -def bcc_slab(alat: float, - atoms, - miller: Union[int, str, Tuple[int, int, int]], - layers=None, - vacuum: Union[float, Sequence[float]]=20., - *, - orthogonal: bool=False, - start=None, - end=None): - r""" Construction of a surface slab from a body-centered cubic (BCC) crystal +def bcc_slab( + alat: float, + atoms, + miller: Union[int, str, Tuple[int, int, int]], + layers=None, + vacuum: Union[float, Sequence[float]] = 20.0, + *, + orthogonal: bool = False, + start=None, + end=None, +): + r"""Construction of a surface slab from a body-centered cubic (BCC) crystal The slab layers are stacked along the :math:`z`-axis. The default stacking is the first layer as an A-layer, defined as the plane containing an atom at :math:`(x,y)=(0,0)`. @@ -533,17 +568,23 @@ def bcc_slab(alat: float, fcc_slab : Slab in FCC structure rocksalt_slab : Slab in rocksalt/halite structure """ - geom = _slab_with_vacuum(bcc_slab, alat, atoms, miller, - vacuum=vacuum, orthogonal=orthogonal, - layers=layers, - start=start, end=end) + geom = _slab_with_vacuum( + bcc_slab, + alat, + atoms, + miller, + vacuum=vacuum, + orthogonal=orthogonal, + layers=layers, + start=start, + end=end, + ) if geom is not None: return geom miller = _convert_miller(miller) if miller == (1, 0, 0): - info = _calc_info(start, end, layers, 2) lattice = Lattice(np.array([1, 1, 0.5]) * alat) @@ -555,26 +596,27 @@ def bcc_slab(alat: float, g.xyz[B::2] += (lattice.cell[0] + lattice.cell[1]) / 2 elif miller == (1, 1, 0): - info = _calc_info(start, end, layers, 2) if orthogonal: lattice = Lattice(np.array([1, 2, 0.5]) ** 0.5 * alat) - g = Geometry(np.array([[0, 0, 0], - [0.5, 0.5 ** 0.5, 0]]) * alat, - atoms=atoms, lattice=lattice) + g = Geometry( + np.array([[0, 0, 0], [0.5, 0.5**0.5, 0]]) * alat, + atoms=atoms, + lattice=lattice, + ) g = g.tile(info.nlayers, 2) # slide ABC layers relative to each other B = 2 * (info.offset + 1) % 4 vec = lattice.cell[1] / 2 g.xyz[B::4] += vec - g.xyz[B+1::4] += vec + g.xyz[B + 1 :: 4] += vec else: - lattice = Lattice(np.array([[1, 0, 0], - [0.5, 0.5 ** 0.5, 0], - [0, 0, 0.5 ** 0.5]]) * alat) + lattice = Lattice( + np.array([[1, 0, 0], [0.5, 0.5**0.5, 0], [0, 0, 0.5**0.5]]) * alat + ) g = Geometry([0, 0, 0], atoms=atoms, lattice=lattice) g = g.tile(info.nlayers, 2) @@ -583,14 +625,15 @@ def bcc_slab(alat: float, g.xyz[B::2] += lattice.cell[0] / 2 elif miller == (1, 1, 1): - info = _calc_info(start, end, layers, 3) if orthogonal: lattice = Lattice(np.array([2, 4 * 1.5, 1 / 12]) ** 0.5 * alat) - g = Geometry(np.array([[0, 0, 0], - [0.5, 1.5, 0]]) ** 0.5 * alat, - atoms=atoms, lattice=lattice) + g = Geometry( + np.array([[0, 0, 0], [0.5, 1.5, 0]]) ** 0.5 * alat, + atoms=atoms, + lattice=lattice, + ) g = g.tile(info.nlayers, 2) # slide ABC layers relative to each other @@ -598,13 +641,13 @@ def bcc_slab(alat: float, C = 2 * (info.offset + 2) % 6 vec = (lattice.cell[0] + lattice.cell[1]) / 3 for i in range(2): - g.xyz[B+i::6] += vec - g.xyz[C+i::6] += 2 * vec + g.xyz[B + i :: 6] += vec + g.xyz[C + i :: 6] += 2 * vec else: - lattice = Lattice(np.array([[2, 0, 0], - [0.5, 1.5, 0], - [0, 0, 1 / 12]]) ** 0.5 * alat) + lattice = Lattice( + np.array([[2, 0, 0], [0.5, 1.5, 0], [0, 0, 1 / 12]]) ** 0.5 * alat + ) g = Geometry([0, 0, 0], atoms=atoms, lattice=lattice) g = g.tile(info.nlayers, 2) @@ -623,16 +666,18 @@ def bcc_slab(alat: float, @set_module("sisl.geom") -def rocksalt_slab(alat: float, - atoms, - miller: Union[int, str, Tuple[int, int, int]], - layers=None, - vacuum: Union[float, Sequence[float]]=20., - *, - orthogonal: bool=False, - start=None, - end=None): - r""" Surface slab forming a rock-salt crystal (halite) +def rocksalt_slab( + alat: float, + atoms, + miller: Union[int, str, Tuple[int, int, int]], + layers=None, + vacuum: Union[float, Sequence[float]] = 20.0, + *, + orthogonal: bool = False, + start=None, + end=None, +): + r"""Surface slab forming a rock-salt crystal (halite) This structure is formed by two interlocked fcc crystals for each of the two elements. @@ -705,10 +750,17 @@ def rocksalt_slab(alat: float, fcc_slab : Slab in FCC structure (this slab is a combination of fcc slab structures) bcc_slab : Slab in BCC structure """ - geom = _slab_with_vacuum(rocksalt_slab, alat, atoms, miller, - vacuum=vacuum, orthogonal=orthogonal, - layers=layers, - start=start, end=end) + geom = _slab_with_vacuum( + rocksalt_slab, + alat, + atoms, + miller, + vacuum=vacuum, + orthogonal=orthogonal, + layers=layers, + start=start, + end=end, + ) if geom is not None: return geom @@ -719,8 +771,26 @@ def rocksalt_slab(alat: float, miller = _convert_miller(miller) - g1 = fcc_slab(alat, atoms[0], miller, layers=layers, vacuum=None, orthogonal=orthogonal, start=start, end=end) - g2 = fcc_slab(alat, atoms[1], miller, layers=layers, vacuum=None, orthogonal=orthogonal, start=start, end=end) + g1 = fcc_slab( + alat, + atoms[0], + miller, + layers=layers, + vacuum=None, + orthogonal=orthogonal, + start=start, + end=end, + ) + g2 = fcc_slab( + alat, + atoms[1], + miller, + layers=layers, + vacuum=None, + orthogonal=orthogonal, + start=start, + end=end, + ) if miller == (1, 0, 0): g2 = g2.move(np.array([0.5, 0.5, 0]) ** 0.5 * alat / 2) diff --git a/src/sisl/geom/tests/test_geom.py b/src/sisl/geom/tests/test_geom.py index ac97c5eef0..3112b55450 100644 --- a/src/sisl/geom/tests/test_geom.py +++ b/src/sisl/geom/tests/test_geom.py @@ -16,7 +16,6 @@ class CellDirect(Lattice): - @property def volume(self): return dot3(self.cell[0, :], cross3(self.cell[1, :], self.cell[2, :])) @@ -24,33 +23,34 @@ def volume(self): def is_right_handed(geometry): sc = CellDirect(geometry.lattice.cell) - return sc.volume > 0. + return sc.volume > 0.0 def test_basic(): - a = sc(2.52, Atom['Fe']) + a = sc(2.52, Atom["Fe"]) assert is_right_handed(a) - a = bcc(2.52, Atom['Fe']) + a = bcc(2.52, Atom["Fe"]) assert is_right_handed(a) - a = bcc(2.52, Atom['Fe'], orthogonal=True) - a = fcc(2.52, Atom['Au']) + a = bcc(2.52, Atom["Fe"], orthogonal=True) + a = fcc(2.52, Atom["Au"]) assert is_right_handed(a) - a = fcc(2.52, Atom['Au'], orthogonal=True) - a = hcp(2.52, Atom['Au']) + a = fcc(2.52, Atom["Au"], orthogonal=True) + a = hcp(2.52, Atom["Au"]) assert is_right_handed(a) - a = hcp(2.52, Atom['Au'], orthogonal=True) - a = rocksalt(5.64, ['Na', 'Cl']) + a = hcp(2.52, Atom["Au"], orthogonal=True) + a = rocksalt(5.64, ["Na", "Cl"]) assert is_right_handed(a) - a = rocksalt(5.64, [Atom('Na', R=3), Atom('Cl', R=4)], orthogonal=True) + a = rocksalt(5.64, [Atom("Na", R=3), Atom("Cl", R=4)], orthogonal=True) def test_flat(): a = graphene() assert is_right_handed(a) - graphene(atoms='C') + graphene(atoms="C") a = graphene(orthogonal=True) assert is_right_handed(a) + def test_flat_flakes(): g = graphene_flake(shells=0, bond=1.42) assert g.na == 6 @@ -65,11 +65,14 @@ def test_flat_flakes(): assert len(g.axyz(AtomNeighbours(min=2, max=2, R=1.44))) == 12 assert len(g.axyz(AtomNeighbours(min=3, max=3, R=1.44))) == 12 - bn = honeycomb_flake(shells=1, atoms=['B', 'N'], bond=1.42) + bn = honeycomb_flake(shells=1, atoms=["B", "N"], bond=1.42) assert bn.na == 24 assert np.allclose(bn.xyz, g.xyz) # Check that atoms are alternated. - assert len(bn.axyz(AtomZ(5) & AtomNeighbours(min=1, R=1.44, neighbour=AtomZ(5)))) == 0 + assert ( + len(bn.axyz(AtomZ(5) & AtomNeighbours(min=1, R=1.44, neighbour=AtomZ(5)))) == 0 + ) + def test_nanotube(): a = nanotube(1.42) @@ -88,40 +91,40 @@ def test_diamond(): def test_bilayer(): a = bilayer(1.42) assert is_right_handed(a) - bilayer(1.42, stacking='AA') - bilayer(1.42, stacking='BA') - bilayer(1.42, stacking='AB') + bilayer(1.42, stacking="AA") + bilayer(1.42, stacking="BA") + bilayer(1.42, stacking="AB") for m in range(7): bilayer(1.42, twist=(m, m + 1)) - bilayer(1.42, twist=(6, 7), layer='bottom') - bilayer(1.42, twist=(6, 7), layer='TOP') - bilayer(1.42, bottom_atoms=(Atom['B'], Atom['N']), twist=(6, 7)) + bilayer(1.42, twist=(6, 7), layer="bottom") + bilayer(1.42, twist=(6, 7), layer="TOP") + bilayer(1.42, bottom_atoms=(Atom["B"], Atom["N"]), twist=(6, 7)) bilayer(1.42, top_atoms=(Atom(5), Atom(7)), twist=(6, 7)) _, _ = bilayer(1.42, twist=(6, 7), ret_angle=True) with pytest.raises(ValueError): - bilayer(1.42, twist=(6, 7), layer='undefined') + bilayer(1.42, twist=(6, 7), layer="undefined") with pytest.raises(ValueError): - bilayer(1.42, twist=(6, 7), stacking='undefined') + bilayer(1.42, twist=(6, 7), stacking="undefined") with pytest.raises(ValueError): - bilayer(1.42, twist=('str', 7), stacking='undefined') + bilayer(1.42, twist=("str", 7), stacking="undefined") def test_nanoribbon(): for w in range(0, 5): - nanoribbon(w, 1.42, Atom(6), kind='armchair') - nanoribbon(w, 1.42, Atom(6), kind='zigzag') - nanoribbon(w, 1.42, (Atom(5), Atom(7)), kind='armchair') - a = nanoribbon(w, 1.42, (Atom(5), Atom(7)), kind='zigzag') + nanoribbon(w, 1.42, Atom(6), kind="armchair") + nanoribbon(w, 1.42, Atom(6), kind="zigzag") + nanoribbon(w, 1.42, (Atom(5), Atom(7)), kind="armchair") + a = nanoribbon(w, 1.42, (Atom(5), Atom(7)), kind="zigzag") assert is_right_handed(a) with pytest.raises(ValueError): - nanoribbon(6, 1.42, (Atom(5), Atom(7)), kind='undefined') + nanoribbon(6, 1.42, (Atom(5), Atom(7)), kind="undefined") with pytest.raises(ValueError): - nanoribbon('str', 1.42, (Atom(5), Atom(7)), kind='undefined') + nanoribbon("str", 1.42, (Atom(5), Atom(7)), kind="undefined") def test_graphene_nanoribbon(): @@ -140,11 +143,14 @@ def test_zgnr(): @pytest.mark.parametrize( - "W, invert_first", itertools.product(range(3, 20), [True, False]), + "W, invert_first", + itertools.product(range(3, 20), [True, False]), ) def test_heteroribbon_one_unit(W, invert_first): # Check that all ribbon widths are properly cut into one unit - geometry = heteroribbon([(W, 1)], bond=1.42, atoms=Atom(6, 1.43), invert_first=invert_first) + geometry = heteroribbon( + [(W, 1)], bond=1.42, atoms=Atom(6, 1.43), invert_first=invert_first + ) assert geometry.na == W @@ -158,7 +164,9 @@ def test_heteroribbon(): L = itertools.repeat(2) for Ws in combinations: - geom = heteroribbon(zip(Ws, L), bond=1.42, atoms=Atom(6, 1.43), align="auto", shift_quantum=True) + geom = heteroribbon( + zip(Ws, L), bond=1.42, atoms=Atom(6, 1.43), align="auto", shift_quantum=True + ) # Assert no dangling bonds. assert len(geom.asc2uc({"neighbours": 1})) == 0 @@ -172,7 +180,9 @@ def test_graphene_heteroribbon_errors(): # 7-open with 9 can only be perfectly aligned. graphene_heteroribbon([(7, 1), (9, 1)], align="center", on_lone_atom="raise") with pytest.raises(SislError): - graphene_heteroribbon([(7, 1), (9, 1, -1)], align="center", on_lone_atom="raise") + graphene_heteroribbon( + [(7, 1), (9, 1, -1)], align="center", on_lone_atom="raise" + ) # From the bottom graphene_heteroribbon([(7, 1), (9, 1, -1)], align="bottom", on_lone_atom="raise") with pytest.raises(SislError): @@ -182,10 +192,7 @@ def test_graphene_heteroribbon_errors(): with pytest.raises(SislError): graphene_heteroribbon([(7, 1), (9, 1, -1)], align="top", on_lone_atom="raise") - - grap_heteroribbon = partial( - graphene_heteroribbon, align="auto", shift_quantum=True - ) + grap_heteroribbon = partial(graphene_heteroribbon, align="auto", shift_quantum=True) # Odd section with open end with pytest.raises(SislError): @@ -209,105 +216,127 @@ def test_graphene_heteroribbon_errors(): grap_heteroribbon([(10, 2), (8, 2, -1)]) with pytest.raises(SislError): grap_heteroribbon([(10, 2), (8, 2, 1)]) - grap_heteroribbon([(10, 1), (8, 2, 1)],) #pbc=False) + grap_heteroribbon( + [(10, 1), (8, 2, 1)], + ) # pbc=False) with pytest.raises(SislError): - grap_heteroribbon([(10, 1), (8, 2, -1)],) #pbc=False) + grap_heteroribbon( + [(10, 1), (8, 2, -1)], + ) # pbc=False) def test_fcc_slab(): for o in [True, False]: - fcc_slab(alat=4.08, atoms='Au', miller=(1, 0, 0), orthogonal=o) - fcc_slab(4.08, 'Au', 100, orthogonal=o) - fcc_slab(4.08, 'Au', 110, orthogonal=o) - fcc_slab(4.08, 'Au', 111, orthogonal=o) - fcc_slab(4.08, 79, '100', layers=5, vacuum=None, orthogonal=o) - fcc_slab(4.08, 79, '110', layers=5, orthogonal=o) - fcc_slab(4.08, 79, '111', layers=5, start=1, orthogonal=o) - fcc_slab(4.08, 79, '111', layers=5, start='C', orthogonal=o) - fcc_slab(4.08, 79, '111', layers=5, end=2, orthogonal=o) - a = fcc_slab(4.08, 79, '111', layers=5, end='B', orthogonal=o) + fcc_slab(alat=4.08, atoms="Au", miller=(1, 0, 0), orthogonal=o) + fcc_slab(4.08, "Au", 100, orthogonal=o) + fcc_slab(4.08, "Au", 110, orthogonal=o) + fcc_slab(4.08, "Au", 111, orthogonal=o) + fcc_slab(4.08, 79, "100", layers=5, vacuum=None, orthogonal=o) + fcc_slab(4.08, 79, "110", layers=5, orthogonal=o) + fcc_slab(4.08, 79, "111", layers=5, start=1, orthogonal=o) + fcc_slab(4.08, 79, "111", layers=5, start="C", orthogonal=o) + fcc_slab(4.08, 79, "111", layers=5, end=2, orthogonal=o) + a = fcc_slab(4.08, 79, "111", layers=5, end="B", orthogonal=o) assert is_right_handed(a) with pytest.raises(ValueError): - fcc_slab(4.08, 'Au', 100, start=0, end=0) + fcc_slab(4.08, "Au", 100, start=0, end=0) with pytest.raises(ValueError): - fcc_slab(4.08, 'Au', 1000) + fcc_slab(4.08, "Au", 1000) with pytest.raises(NotImplementedError): - fcc_slab(4.08, 'Au', 200) + fcc_slab(4.08, "Au", 200) assert not np.allclose( - fcc_slab(5.64, 'Au', 100, end=1, layers='BABAB').xyz, - fcc_slab(5.64, 'Au', 100, end=1, layers=' BABAB ').xyz) + fcc_slab(5.64, "Au", 100, end=1, layers="BABAB").xyz, + fcc_slab(5.64, "Au", 100, end=1, layers=" BABAB ").xyz, + ) assert np.allclose( - fcc_slab(5.64, 'Au', 100, layers=' AB AB BA ', vacuum=2).xyz, - fcc_slab(5.64, 'Au', 100, layers=(None, 2, 2, 2, None), vacuum=2, start=(0, 0, 1)).xyz) + fcc_slab(5.64, "Au", 100, layers=" AB AB BA ", vacuum=2).xyz, + fcc_slab( + 5.64, "Au", 100, layers=(None, 2, 2, 2, None), vacuum=2, start=(0, 0, 1) + ).xyz, + ) assert np.allclose( - fcc_slab(5.64, 'Au', 100, layers=' AB AB BA ', vacuum=(2, 1)).xyz, - fcc_slab(5.64, 'Au', 100, layers=(None, 2, ' ', 2, None, 2, None), vacuum=(2, 1), end=(1, 1, 0)).xyz) + fcc_slab(5.64, "Au", 100, layers=" AB AB BA ", vacuum=(2, 1)).xyz, + fcc_slab( + 5.64, + "Au", + 100, + layers=(None, 2, " ", 2, None, 2, None), + vacuum=(2, 1), + end=(1, 1, 0), + ).xyz, + ) # example in documentation assert np.allclose( - fcc_slab(4., 'Au', 100, layers=(' ', 3, 5, 3), start=(0, 1, 0), vacuum=(10, 1, 2)).xyz, - fcc_slab(4., 'Au', 100, layers=' ABA BABAB ABA', vacuum=(10, 1, 2)).xyz) + fcc_slab( + 4.0, "Au", 100, layers=(" ", 3, 5, 3), start=(0, 1, 0), vacuum=(10, 1, 2) + ).xyz, + fcc_slab(4.0, "Au", 100, layers=" ABA BABAB ABA", vacuum=(10, 1, 2)).xyz, + ) assert np.allclose( - fcc_slab(4., 'Au', 100, layers=(' ', 3, 5, 3), start=(1, 0), vacuum=(10, 1, 2)).xyz, - fcc_slab(4., 'Au', 100, layers=' BAB ABABA ABA', vacuum=(10, 1, 2)).xyz) + fcc_slab( + 4.0, "Au", 100, layers=(" ", 3, 5, 3), start=(1, 0), vacuum=(10, 1, 2) + ).xyz, + fcc_slab(4.0, "Au", 100, layers=" BAB ABABA ABA", vacuum=(10, 1, 2)).xyz, + ) def test_bcc_slab(): for o in [True, False]: - bcc_slab(alat=4.08, atoms='Au', miller=(1, 0, 0), orthogonal=o) - bcc_slab(4.08, 'Au', 100, orthogonal=o) - bcc_slab(4.08, 'Au', 110, orthogonal=o) - bcc_slab(4.08, 'Au', 111, orthogonal=o) - bcc_slab(4.08, 79, '100', layers=5, vacuum=None, orthogonal=o) - bcc_slab(4.08, 79, '110', layers=5, orthogonal=o) - assert (bcc_slab(4.08, 79, '111', layers=5, start=1, orthogonal=o) - .equal( - bcc_slab(4.08, 79, '111', layers="BCABC", orthogonal=o) - )) - bcc_slab(4.08, 79, '111', layers="BCABC", start='B', orthogonal=o) - bcc_slab(4.08, 79, '111', layers="BCABC", start=1, orthogonal=o) - bcc_slab(4.08, 79, '111', layers=5, start='C', orthogonal=o) - bcc_slab(4.08, 79, '111', layers=5, end=2, orthogonal=o) - a = bcc_slab(4.08, 79, '111', layers=5, end='B', orthogonal=o) + bcc_slab(alat=4.08, atoms="Au", miller=(1, 0, 0), orthogonal=o) + bcc_slab(4.08, "Au", 100, orthogonal=o) + bcc_slab(4.08, "Au", 110, orthogonal=o) + bcc_slab(4.08, "Au", 111, orthogonal=o) + bcc_slab(4.08, 79, "100", layers=5, vacuum=None, orthogonal=o) + bcc_slab(4.08, 79, "110", layers=5, orthogonal=o) + assert bcc_slab(4.08, 79, "111", layers=5, start=1, orthogonal=o).equal( + bcc_slab(4.08, 79, "111", layers="BCABC", orthogonal=o) + ) + bcc_slab(4.08, 79, "111", layers="BCABC", start="B", orthogonal=o) + bcc_slab(4.08, 79, "111", layers="BCABC", start=1, orthogonal=o) + bcc_slab(4.08, 79, "111", layers=5, start="C", orthogonal=o) + bcc_slab(4.08, 79, "111", layers=5, end=2, orthogonal=o) + a = bcc_slab(4.08, 79, "111", layers=5, end="B", orthogonal=o) assert is_right_handed(a) with pytest.raises(ValueError): - bcc_slab(4.08, 'Au', 100, start=0, end=0) + bcc_slab(4.08, "Au", 100, start=0, end=0) with pytest.raises(ValueError): - bcc_slab(4.08, 'Au', 1000) + bcc_slab(4.08, "Au", 1000) with pytest.raises(NotImplementedError): - bcc_slab(4.08, 'Au', 200) + bcc_slab(4.08, "Au", 200) assert not np.allclose( - bcc_slab(5.64, 'Au', 100, end=1, layers='BABAB').xyz, - bcc_slab(5.64, 'Au', 100, end=1, layers=' BABAB ').xyz) + bcc_slab(5.64, "Au", 100, end=1, layers="BABAB").xyz, + bcc_slab(5.64, "Au", 100, end=1, layers=" BABAB ").xyz, + ) def test_rocksalt_slab(): rocksalt_slab(5.64, [Atom(11, R=3), Atom(17, R=4)], 100) - assert (rocksalt_slab(5.64, ['Na', 'Cl'], 100, layers=5) - .equal( - rocksalt_slab(5.64, ['Na', 'Cl'], 100, layers="ABABA") - )) - rocksalt_slab(5.64, ['Na', 'Cl'], 110, vacuum=None) - rocksalt_slab(5.64, ['Na', 'Cl'], 111, orthogonal=False) - rocksalt_slab(5.64, ['Na', 'Cl'], 111, orthogonal=True) - a = rocksalt_slab(5.64, 'Na', 100) + assert rocksalt_slab(5.64, ["Na", "Cl"], 100, layers=5).equal( + rocksalt_slab(5.64, ["Na", "Cl"], 100, layers="ABABA") + ) + rocksalt_slab(5.64, ["Na", "Cl"], 110, vacuum=None) + rocksalt_slab(5.64, ["Na", "Cl"], 111, orthogonal=False) + rocksalt_slab(5.64, ["Na", "Cl"], 111, orthogonal=True) + a = rocksalt_slab(5.64, "Na", 100) assert is_right_handed(a) with pytest.raises(ValueError): - rocksalt_slab(5.64, ['Na', 'Cl'], 100, start=0, end=0) + rocksalt_slab(5.64, ["Na", "Cl"], 100, start=0, end=0) with pytest.raises(ValueError): - rocksalt_slab(5.64, ['Na', 'Cl'], 1000) + rocksalt_slab(5.64, ["Na", "Cl"], 1000) with pytest.raises(NotImplementedError): - rocksalt_slab(5.64, ['Na', 'Cl'], 200) + rocksalt_slab(5.64, ["Na", "Cl"], 200) with pytest.raises(ValueError): - rocksalt_slab(5.64, ['Na', 'Cl'], 100, start=0, layers='BABAB') - rocksalt_slab(5.64, ['Na', 'Cl'], 100, start=1, layers='BABAB') + rocksalt_slab(5.64, ["Na", "Cl"], 100, start=0, layers="BABAB") + rocksalt_slab(5.64, ["Na", "Cl"], 100, start=1, layers="BABAB") with pytest.raises(ValueError): - rocksalt_slab(5.64, ['Na', 'Cl'], 100, end=0, layers='BABAB') - rocksalt_slab(5.64, ['Na', 'Cl'], 100, end=1, layers='BABAB') + rocksalt_slab(5.64, ["Na", "Cl"], 100, end=0, layers="BABAB") + rocksalt_slab(5.64, ["Na", "Cl"], 100, end=1, layers="BABAB") assert not np.allclose( - rocksalt_slab(5.64, ['Na', 'Cl'], 100, end=1, layers='BABAB').xyz, - rocksalt_slab(5.64, ['Na', 'Cl'], 100, end=1, layers=' BABAB ').xyz) + rocksalt_slab(5.64, ["Na", "Cl"], 100, end=1, layers="BABAB").xyz, + rocksalt_slab(5.64, ["Na", "Cl"], 100, end=1, layers=" BABAB ").xyz, + ) diff --git a/src/sisl/geometry.py b/src/sisl/geometry.py index ed9c740613..9fabb4a12d 100644 --- a/src/sisl/geometry.py +++ b/src/sisl/geometry.py @@ -36,7 +36,12 @@ from sisl._typing_ext.numpy import ArrayLike, NDArray if TYPE_CHECKING: - from sisl.typing import AtomsArgument, OrbitalsArgument, SileType, LatticeOrGeometryLike + from sisl.typing import ( + AtomsArgument, + OrbitalsArgument, + SileType, + LatticeOrGeometryLike, + ) from . import _array as _a from . import _plot as plt @@ -71,7 +76,7 @@ ) from .utils.mathematics import fnorm -__all__ = ['Geometry', "sgeom"] +__all__ = ["Geometry", "sgeom"] _log = logging.getLogger("sisl") _log.info(f"adding logger: {__name__}") @@ -93,18 +98,21 @@ def is_class(cls, name, case=True) -> bool: @set_module("sisl") -class Geometry(LatticeChild, _Dispatchs, - dispatchs=[ - ("new", ClassDispatcher("new", - obj_getattr="error", - instance_dispatcher=TypeDispatcher)), - ("to", ClassDispatcher("to", - obj_getattr="error", - type_dispatcher=None)) - ], - when_subclassing="copy", - ): - """ Holds atomic information, coordinates, species, lattice vectors +class Geometry( + LatticeChild, + _Dispatchs, + dispatchs=[ + ( + "new", + ClassDispatcher( + "new", obj_getattr="error", instance_dispatcher=TypeDispatcher + ), + ), + ("to", ClassDispatcher("to", obj_getattr="error", type_dispatcher=None)), + ], + when_subclassing="copy", +): + """Holds atomic information, coordinates, species, lattice vectors The `Geometry` class holds information regarding atomic coordinates, the atomic species, the corresponding lattice-vectors. @@ -174,21 +182,19 @@ class Geometry(LatticeChild, _Dispatchs, Atom : contained atoms are each an object of this """ - @deprecate_argument("sc", "lattice", - "argument sc has been deprecated in favor of lattice, please update your code.", - "0.15.0") - def __init__(self, - xyz: ArrayLike, - atoms=None, - lattice=None, - names=None): - + @deprecate_argument( + "sc", + "lattice", + "argument sc has been deprecated in favor of lattice, please update your code.", + "0.15.0", + ) + def __init__(self, xyz: ArrayLike, atoms=None, lattice=None, names=None): # Create the geometry coordinate, be aware that we do not copy! self.xyz = _a.asarrayd(xyz).reshape(-1, 3) # Default value if atoms is None: - atoms = Atom('H') + atoms = Atom("H") # Create the local Atoms object self._atoms = Atoms(atoms, na=self.na) @@ -202,7 +208,7 @@ def __init__(self, self._init_lattice(lattice) def _init_lattice(self, lattice): - """ Initializes the supercell by *calculating* the size if not supplied + """Initializes the supercell by *calculating* the size if not supplied If the supercell has not been passed we estimate the unit cell size by calculating the bond-length in each direction for a square @@ -217,22 +223,22 @@ def _init_lattice(self, lattice): # First create an initial guess for the supercell # It HAS to be VERY large to not interact - closest = self.close(0, R=(0., 0.4, 5.))[2] + closest = self.close(0, R=(0.0, 0.4, 5.0))[2] if len(closest) < 1: # We could not find any atoms very close, # hence we simply return and now it becomes # the users responsibility # We create a molecule box with +10 A in each direction - m, M = np.amin(self.xyz, axis=0), np.amax(self.xyz, axis=0) + 10. - self.set_lattice(M-m) + m, M = np.amin(self.xyz, axis=0), np.amax(self.xyz, axis=0) + 10.0 + self.set_lattice(M - m) return sc_cart = _a.zerosd([3]) cart = _a.zerosd([3]) for i in range(3): # Initialize cartesian direction - cart[i] = 1. + cart[i] = 1.0 # Get longest distance between atoms max_dist = np.amax(self.xyz[:, i]) - np.amin(self.xyz[:, i]) @@ -242,7 +248,7 @@ def _init_lattice(self, lattice): dd = np.abs(dot(dist, cart)) # Remove all below .4 - tmp_idx = (dd >= .4).nonzero()[0] + tmp_idx = (dd >= 0.4).nonzero()[0] if len(tmp_idx) > 0: # We have a success # Add the bond-distance in the Cartesian direction @@ -251,84 +257,84 @@ def _init_lattice(self, lattice): else: # Default to LARGE array so as no # interaction occurs (it may be 2D) - sc_cart[i] = max(10., max_dist) - cart[i] = 0. + sc_cart[i] = max(10.0, max_dist) + cart[i] = 0.0 # Re-set the supercell to the newly found one self.set_lattice(sc_cart) @property def atoms(self) -> Atoms: - """ Atoms for the geometry (`Atoms` object) """ + """Atoms for the geometry (`Atoms` object)""" return self._atoms @property def names(self): - """ The named index specifier """ + """The named index specifier""" return self._names @property def q0(self) -> float: - """ Total initial charge in this geometry (sum of q0 in all atoms) """ + """Total initial charge in this geometry (sum of q0 in all atoms)""" return self.atoms.q0.sum() @property def mass(self) -> ndarray: - """ The mass of all atoms as an array """ + """The mass of all atoms as an array""" return self.atoms.mass - def maxR(self, all: bool=False) -> float: - """ Maximum orbital range of the atoms """ + def maxR(self, all: bool = False) -> float: + """Maximum orbital range of the atoms""" return self.atoms.maxR(all) @property def na(self) -> int: - """ Number of atoms in geometry """ + """Number of atoms in geometry""" return self.xyz.shape[0] @property def na_s(self) -> int: - """ Number of supercell atoms """ + """Number of supercell atoms""" return self.na * self.n_s def __len__(self) -> int: - """ Number of atoms in geometry in unit cell""" + """Number of atoms in geometry in unit cell""" return self.na @property def no(self) -> int: - """ Number of orbitals in unit cell """ + """Number of orbitals in unit cell""" return self.atoms.no @property def no_s(self) -> int: - """ Number of supercell orbitals """ + """Number of supercell orbitals""" return self.no * self.n_s @property def firsto(self) -> NDArray[np.int32]: - """ The first orbital on the corresponding atom """ + """The first orbital on the corresponding atom""" return self.atoms.firsto @property def lasto(self) -> NDArray[np.int32]: - """ The last orbital on the corresponding atom """ + """The last orbital on the corresponding atom""" return self.atoms.lasto @property def orbitals(self) -> ndarray: - """ List of orbitals per atom """ + """List of orbitals per atom""" return self.atoms.orbitals ## End size of geometry @property def fxyz(self) -> NDArray[np.float64]: - """ Returns geometry coordinates in fractional coordinates """ + """Returns geometry coordinates in fractional coordinates""" return dot(self.xyz, self.icell.T) def __setitem__(self, atoms, value): - """ Specify geometry coordinates """ + """Specify geometry coordinates""" if isinstance(atoms, str): self.names.add_name(atoms, value) elif isinstance(value, str): @@ -336,7 +342,7 @@ def __setitem__(self, atoms, value): @singledispatchmethod def __getitem__(self, atoms) -> ndarray: - """ Geometry coordinates (allows supercell indices) """ + """Geometry coordinates (allows supercell indices)""" return self.axyz(atoms) @__getitem__.register @@ -353,7 +359,7 @@ def _(self, atoms: tuple) -> ndarray: @singledispatchmethod def _sanitize_atoms(self, atoms) -> ndarray: - """ Converts an `atoms` to index under given inputs + """Converts an `atoms` to index under given inputs `atoms` may be one of the following: @@ -390,6 +396,7 @@ def _(self, atoms: Atom) -> ndarray: def _(self, atoms) -> ndarray: # First do categorization cat = atoms.categorize(self) + def m(cat): for ia, c in enumerate(cat): if c == None: @@ -397,6 +404,7 @@ def m(cat): pass else: yield ia + return _a.fromiterl(m(cat)) @_sanitize_atoms.register @@ -414,7 +422,7 @@ def _(self, atoms: Shape) -> ndarray: @singledispatchmethod def _sanitize_orbs(self, orbitals) -> ndarray: - """ Converts an `orbital` to index under given inputs + """Converts an `orbital` to index under given inputs `orbital` may be one of the following: @@ -453,16 +461,19 @@ def _(self, orbitals: Shape) -> ndarray: @_sanitize_orbs.register def _(self, orbitals: dict) -> ndarray: - """ A dict has atoms as keys """ + """A dict has atoms as keys""" + def conv(atom, orbs): atom = self._sanitize_atoms(atom) return np.add.outer(self.firsto[atom], orbs).ravel() - return np.concatenate(tuple(conv(atom, orbs) for atom, orbs in orbitals.items())) - def as_primary(self, - na_primary: int, - axes=(0, 1, 2), - ret_super: bool=False) -> Union[Geometry, Tuple[Geometry, Lattice]]: + return np.concatenate( + tuple(conv(atom, orbs) for atom, orbs in orbitals.items()) + ) + + def as_primary( + self, na_primary: int, axes=(0, 1, 2), ret_super: bool = False + ) -> Union[Geometry, Tuple[Geometry, Lattice]]: """Reduce the geometry to the primary unit-cell comprising `na_primary` atoms This will basically try and find the tiling/repetitions required for the geometry to only have @@ -491,8 +502,10 @@ def as_primary(self, """ na = len(self) if na % na_primary != 0: - raise ValueError(f'{self.__class__.__name__}.as_primary requires the number of atoms to be divisable by the ' - 'total number of atoms.') + raise ValueError( + f"{self.__class__.__name__}.as_primary requires the number of atoms to be divisable by the " + "total number of atoms." + ) axes = _a.arrayi(axes) @@ -515,7 +528,6 @@ def as_primary(self, n_bin = n_supercells while n_bin > 1: - # Create bins bins = np.linspace(0, 1, n_bin + 1) @@ -547,7 +559,9 @@ def as_primary(self, # Check that the number of supercells match if np.prod(supercell) != n_supercells: - raise SislError(f'{self.__class__.__name__}.as_primary could not determine the optimal supercell.') + raise SislError( + f"{self.__class__.__name__}.as_primary could not determine the optimal supercell." + ) # Cut down the supercell (TODO this does not correct the number of supercell connections!) lattice = self.lattice.copy() @@ -569,7 +583,7 @@ def as_primary(self, fxyz += min_fxyz * 0.05 # Find all fractional indices that are below 1 - ind = np.logical_and.reduce(fxyz < 1., axis=1).nonzero()[0] + ind = np.logical_and.reduce(fxyz < 1.0, axis=1).nonzero()[0] geom = self.sub(ind) geom.set_lattice(lattice) @@ -606,6 +620,7 @@ def as_supercell(self) -> Geometry: # be made to look like the `self` supercell indices isc_sc = np.rint(sc.xyz[::na] @ self.icell.T - f0).astype(np.int32) isc_self = self.a2isc(np.arange(self.n_s) * na) + def new_sub(isc): return (abs(isc_sc - isc).sum(1) == 0).nonzero()[0][0] @@ -639,9 +654,7 @@ def reduce(self) -> None: """ self._atoms = self.atoms.reduce(in_place=True) - def rij(self, - ia: AtomsArgument, - ja: AtomsArgument) -> ndarray: + def rij(self, ia: AtomsArgument, ja: AtomsArgument) -> ndarray: r"""Distance between atom `ia` and `ja`, atoms can be in super-cell indices Returns the distance between two atoms: @@ -659,13 +672,11 @@ def rij(self, R = self.Rij(ia, ja) if len(R.shape) == 1: - return (R[0] ** 2. + R[1] ** 2 + R[2] ** 2) ** .5 + return (R[0] ** 2.0 + R[1] ** 2 + R[2] ** 2) ** 0.5 return fnorm(R) - def Rij(self, - ia: AtomsArgument, - ja: AtomsArgument) -> ndarray: + def Rij(self, ia: AtomsArgument, ja: AtomsArgument) -> ndarray: r"""Vector between atom `ia` and `ja`, atoms can be in super-cell indices Returns the vector between two atoms: @@ -690,9 +701,7 @@ def Rij(self, return xj - xi[None, :] - def orij(self, - orbitals1: OrbitalsArgument, - orbitals2: OrbitalsArgument) -> ndarray: + def orij(self, orbitals1: OrbitalsArgument, orbitals2: OrbitalsArgument) -> ndarray: r"""Distance between orbital `orbitals1` and `orbitals2`, orbitals can be in super-cell indices Returns the distance between two orbitals: @@ -709,9 +718,7 @@ def orij(self, """ return self.rij(self.o2a(orbitals1), self.o2a(orbitals2)) - def oRij(self, - orbitals1: OrbitalsArgument, - orbitals2: OrbitalsArgument) -> ndarray: + def oRij(self, orbitals1: OrbitalsArgument, orbitals2: OrbitalsArgument) -> ndarray: r"""Vector between orbital `orbitals1` and `orbitals2`, orbitals can be in super-cell indices Returns the vector between two orbitals: @@ -729,10 +736,8 @@ def oRij(self, return self.Rij(self.o2a(orbitals1), self.o2a(orbitals2)) @staticmethod - def read(sile: SileType, - *args, - **kwargs) -> Geometry: - """ Reads geometry from the `Sile` using `Sile.read_geometry` + def read(sile: SileType, *args, **kwargs) -> Geometry: + """Reads geometry from the `Sile` using `Sile.read_geometry` Parameters ---------- @@ -747,17 +752,15 @@ def read(sile: SileType, # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_geometry(*args, **kwargs) else: - with get_sile(sile, mode='r') as fh: + with get_sile(sile, mode="r") as fh: return fh.read_geometry(*args, **kwargs) - def write(self, - sile: SileType, - *args, - **kwargs) -> None: - """ Writes geometry to the `Sile` using `sile.write_geometry` + def write(self, sile: SileType, *args, **kwargs) -> None: + """Writes geometry to the `Sile` using `sile.write_geometry` Parameters ---------- @@ -775,22 +778,28 @@ def write(self, # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): sile.write_geometry(self, *args, **kwargs) else: - with get_sile(sile, mode='w') as fh: + with get_sile(sile, mode="w") as fh: fh.write_geometry(self, *args, **kwargs) def __str__(self) -> str: - """ str of the object """ - s = self.__class__.__name__ + f'{{na: {self.na}, no: {self.no},\n ' - s += str(self.atoms).replace('\n', '\n ') + """str of the object""" + s = self.__class__.__name__ + f"{{na: {self.na}, no: {self.no},\n " + s += str(self.atoms).replace("\n", "\n ") if len(self.names) > 0: - s += ',\n ' + str(self.names).replace('\n', '\n ') - return (s + ',\n maxR: {0:.5f},\n {1}\n}}'.format(self.maxR(), str(self.lattice).replace('\n', '\n '))).strip() + s += ",\n " + str(self.names).replace("\n", "\n ") + return ( + s + + ",\n maxR: {0:.5f},\n {1}\n}}".format( + self.maxR(), str(self.lattice).replace("\n", "\n ") + ) + ).strip() def __repr__(self) -> str: - """ A simple, short string representation. """ + """A simple, short string representation.""" return f"<{self.__module__}.{self.__class__.__name__} na={self.na}, no={self.no}, nsc={self.nsc}>" def iter(self) -> Iterator[int]: @@ -815,8 +824,9 @@ def iter(self) -> Iterator[int]: __iter__ = iter - def iter_species(self, - atoms: Optional[AtomsArgument]=None) -> Iterator[int, Atom, int]: + def iter_species( + self, atoms: Optional[AtomsArgument] = None + ) -> Iterator[int, Atom, int]: """Iterator over all atoms (or a subset) and species as a tuple in this geometry >>> for ia, a, idx_specie in self.iter_species(): @@ -844,9 +854,9 @@ def iter_species(self, for ia in self._sanitize_atoms(atoms).ravel(): yield ia, self.atoms[ia], self.atoms.specie[ia] - def iter_orbitals(self, - atoms: Optional[AtomsArgument]=None, - local: bool=True) -> Iterator[int, int]: + def iter_orbitals( + self, atoms: Optional[AtomsArgument] = None, local: bool = True + ) -> Iterator[int, int]: r"""Returns an iterator over all atoms and their associated orbitals >>> for ia, io in self.iter_orbitals(): @@ -886,18 +896,19 @@ def iter_orbitals(self, else: atoms = self._sanitize_atoms(atoms).ravel() if local: - for ia, io1, io2 in zip(atoms, self.firsto[atoms], self.lasto[atoms] + 1): + for ia, io1, io2 in zip( + atoms, self.firsto[atoms], self.lasto[atoms] + 1 + ): for io in range(io2 - io1): yield ia, io else: - for ia, io1, io2 in zip(atoms, self.firsto[atoms], self.lasto[atoms] + 1): + for ia, io1, io2 in zip( + atoms, self.firsto[atoms], self.lasto[atoms] + 1 + ): for io in range(io1, io2): yield ia, io - def iR(self, - na: int=1000, - iR: int=20, - R: Optional[float]=None) -> int: + def iR(self, na: int = 1000, iR: int = 20, R: Optional[float] = None) -> int: """Return an integer number of maximum radii (``self.maxR()``) which holds approximately `na` atoms Parameters @@ -920,21 +931,25 @@ def iR(self, if R is None: R = self.maxR() + 0.001 if R < 0: - raise ValueError(f"{self.__class__.__name__}.iR unable to determine a number of atoms within a sphere with negative radius, is maxR() defined?") + raise ValueError( + f"{self.__class__.__name__}.iR unable to determine a number of atoms within a sphere with negative radius, is maxR() defined?" + ) # Number of atoms within 20 * R naiR = max(1, len(self.close(ia, R=R * iR))) # Convert to na atoms spherical radii - iR = int(4 / 3 * np.pi * R ** 3 / naiR * na) + iR = int(4 / 3 * np.pi * R**3 / naiR * na) return max(2, iR) - def iter_block_rand(self, - iR: int=20, - R: Optional[float]=None, - atoms: Optional[AtomsArgument]=None) -> Iterator[Tuple[ndarray, ndarray]]: - """Perform the *random* block-iteration by randomly selecting the next center of block """ + def iter_block_rand( + self, + iR: int = 20, + R: Optional[float] = None, + atoms: Optional[AtomsArgument] = None, + ) -> Iterator[Tuple[ndarray, ndarray]]: + """Perform the *random* block-iteration by randomly selecting the next center of block""" # We implement yields as we can then do nested iterators # create a boolean array @@ -950,7 +965,7 @@ def iter_block_rand(self, not_passed_N = np.sum(not_passed) if iR < 2: - raise SislError(f'{self.__class__.__name__}.iter_block_rand too small iR!') + raise SislError(f"{self.__class__.__name__}.iter_block_rand too small iR!") if R is None: R = self.maxR() + 0.001 @@ -959,7 +974,6 @@ def iter_block_rand(self, # loop until all passed are true while not_passed_N > 0: - # Take a random non-passed element all_true = not_passed.nonzero()[0] @@ -980,8 +994,9 @@ def iter_block_rand(self, # Get unit-cell atoms, we are drawing a circle, and this # circle only encompasses those already in the unit-cell. - all_idx[1] = np.union1d(self.sc2uc(all_idx[0], unique=True), - self.sc2uc(all_idx[1], unique=True)) + all_idx[1] = np.union1d( + self.sc2uc(all_idx[0], unique=True), self.sc2uc(all_idx[1], unique=True) + ) # If we translated stuff into the unit-cell, we could end up in situations # where the supercell atom is in the circle, but not the UC-equivalent # of that one. @@ -1005,13 +1020,14 @@ def iter_block_rand(self, if np.any(not_passed): print(not_passed.nonzero()[0]) print(np.sum(not_passed), len(self)) - raise SislError(f'{self.__class__.__name__}.iter_block_rand error on iterations. Not all atoms have been visited.') + raise SislError( + f"{self.__class__.__name__}.iter_block_rand error on iterations. Not all atoms have been visited." + ) - def iter_block_shape(self, - shape=None, - iR: int=20, - atoms: Optional[AtomsArgument]=None) -> Iterator[Tuple[ndarray, ndarray]]: - """Perform the *grid* block-iteration by looping a grid """ + def iter_block_shape( + self, shape=None, iR: int = 20, atoms: Optional[AtomsArgument] = None + ) -> Iterator[Tuple[ndarray, ndarray]]: + """Perform the *grid* block-iteration by looping a grid""" # We implement yields as we can then do nested iterators # create a boolean array @@ -1027,22 +1043,23 @@ def iter_block_shape(self, not_passed_N = np.sum(not_passed) if iR < 2: - raise SislError(f'{self.__class__.__name__}.iter_block_shape too small iR!') + raise SislError(f"{self.__class__.__name__}.iter_block_shape too small iR!") R = self.maxR() + 0.001 if shape is None: # we default to the Cube shapes - dS = (Cube((iR - 0.5) * R), - Cube((iR + 1.501) * R)) + dS = (Cube((iR - 0.5) * R), Cube((iR + 1.501) * R)) else: if isinstance(shape, Shape): dS = (shape,) else: dS = tuple(shape) if len(dS) == 1: - dS += (dS[0].expand(R), ) + dS += (dS[0].expand(R),) if len(dS) != 2: - raise ValueError(f'{self.__class__.__name__}.iter_block_shape, number of Shapes *must* be one or two') + raise ValueError( + f"{self.__class__.__name__}.iter_block_shape, number of Shapes *must* be one or two" + ) # Now create the Grid # convert the radius to a square Grid @@ -1055,19 +1072,21 @@ def iter_block_shape(self, # Sphere and Cube for s in dS: if not isinstance(s, (Cube, Sphere)): - raise ValueError(f'{self.__class__.__name__}.iter_block_shape currently only works for ' - 'Cube or Sphere objects. Please change sources.') + raise ValueError( + f"{self.__class__.__name__}.iter_block_shape currently only works for " + "Cube or Sphere objects. Please change sources." + ) # Retrieve the internal diameter if isinstance(dS[0], Cube): ir = dS[0].edge_length elif isinstance(dS[0], Sphere): - ir = [dS[0].radius * 0.5 ** 0.5 * 2] * 3 + ir = [dS[0].radius * 0.5**0.5 * 2] * 3 elif isinstance(dS[0], Shape): # Convert to spheres (which probably should be cubes for performance) dS = [s.to.Sphere() for s in dS] # Now do the same with spheres - ir = [dS[0].radius * 0.5 ** 0.5 * 2] * 3 + ir = [dS[0].radius * 0.5**0.5 * 2] * 3 # Figure out number of segments in each iteration # (minimum 1) @@ -1092,7 +1111,6 @@ def iter_block_shape(self, # Now we loop in each direction for x, y, z in product(range(ixyz[0]), range(ixyz[1]), range(ixyz[2])): - # Create the new center center = xyz_m + [x * dxyz[0], y * dxyz[1], z * dxyz[2]] # Correct in case the iteration steps across the maximum @@ -1106,8 +1124,9 @@ def iter_block_shape(self, # Get unit-cell atoms, we are drawing a circle, and this # circle only encompasses those already in the unit-cell. - all_idx[1] = np.union1d(self.sc2uc(all_idx[0], unique=True), - self.sc2uc(all_idx[1], unique=True)) + all_idx[1] = np.union1d( + self.sc2uc(all_idx[0], unique=True), self.sc2uc(all_idx[1], unique=True) + ) # If we translated stuff into the unit-cell, we could end up in situations # where the supercell atom is in the circle, but not the UC-equivalent # of that one. @@ -1130,14 +1149,18 @@ def iter_block_shape(self, if np.any(not_passed): not_passed = not_passed.nonzero()[0] - raise SislError(f"{self.__class__.__name__}.iter_block_shape error on iterations. Not all atoms have been visited " - f"{not_passed}") - - def iter_block(self, - iR: int=20, - R: Optional[float]=None, - atoms: Optional[AtomsArgument]=None, - method: str='rand') -> Iterator[Tuple[ndarray, ndarray]]: + raise SislError( + f"{self.__class__.__name__}.iter_block_shape error on iterations. Not all atoms have been visited " + f"{not_passed}" + ) + + def iter_block( + self, + iR: int = 20, + R: Optional[float] = None, + atoms: Optional[AtomsArgument] = None, + method: str = "rand", + ) -> Iterator[Tuple[ndarray, ndarray]]: """Iterator for performance critical loops NOTE: This requires that `R` has been set correctly as the maximum interaction range. @@ -1177,7 +1200,7 @@ def iter_block(self, atoms that needs searching """ if iR < 2: - raise SislError(f'{self.__class__.__name__}.iter_block too small iR!') + raise SislError(f"{self.__class__.__name__}.iter_block too small iR!") method = method.lower() if method in ("rand", "random"): @@ -1187,30 +1210,37 @@ def iter_block(self, R = self.maxR() + 0.001 # Create shapes - if method == 'sphere': - dS = (Sphere((iR - 0.5) * R), - Sphere((iR + 0.501) * R)) - elif method == 'cube': - dS = (Cube((2 * iR - 0.5) * R), - # we need an extra R here since it needs to extend on both sides - Cube((2 * iR + 1.501) * R)) + if method == "sphere": + dS = (Sphere((iR - 0.5) * R), Sphere((iR + 0.501) * R)) + elif method == "cube": + dS = ( + Cube((2 * iR - 0.5) * R), + # we need an extra R here since it needs to extend on both sides + Cube((2 * iR + 1.501) * R), + ) yield from self.iter_block_shape(dS) else: - raise ValueError(f"{self.__class__.__name__}.iter_block got unexpected 'method' argument: {method}") + raise ValueError( + f"{self.__class__.__name__}.iter_block got unexpected 'method' argument: {method}" + ) def copy(self) -> Geometry: """Create a new object with the same content (a copy).""" - g = self.__class__(np.copy(self.xyz), atoms=self.atoms.copy(), lattice=self.lattice.copy()) + g = self.__class__( + np.copy(self.xyz), atoms=self.atoms.copy(), lattice=self.lattice.copy() + ) g._names = self.names.copy() return g - def overlap(self, - other: GeometryLikeType, - eps: float=0.1, - offset=(0., 0., 0.), - offset_other=(0., 0., 0.)) -> Tuple[ndarray, ndarray]: - """ Calculate the overlapping indices between two geometries + def overlap( + self, + other: GeometryLikeType, + eps: float = 0.1, + offset=(0.0, 0.0, 0.0), + offset_other=(0.0, 0.0, 0.0), + ) -> Tuple[ndarray, ndarray]: + """Calculate the overlapping indices between two geometries Find equivalent atoms (in the primary unit-cell only) in two geometries. This routine finds which atoms have the same atomic positions in `self` and `other`. @@ -1265,7 +1295,7 @@ def overlap(self, return _a.arrayi(idx_self), _a.arrayi(idx_other) def sort(self, **kwargs) -> Union[Geometry, Tuple[Geometry, List]]: - r""" Sort atoms in a nested fashion according to various criteria + r"""Sort atoms in a nested fashion according to various criteria There are many ways to sort a `Geometry`. - by Cartesian coordinates, `axis` @@ -1468,15 +1498,18 @@ def sort(self, **kwargs) -> Union[Geometry, Tuple[Geometry, List]]: >>> geom.sort(atol=0.001, axis=0, ret_atoms=True)[1] [[0], [1], [2], [3], [4]] """ + # We need a way to easily handle nested lists # This small class handles lists and allows appending nested lists # while flattening them. class NestedList: - __slots__ = ('_idx', ) + __slots__ = ("_idx",) + def __init__(self, idx=None, sort=False): self._idx = [] if not idx is None: self.append(idx, sort) + def append(self, idx, sort=False): if isinstance(idx, (tuple, list, ndarray)): if isinstance(idx[0], (tuple, list, ndarray)): @@ -1490,31 +1523,36 @@ def append(self, idx, sort=False): self._idx.append(np.sort(idx)) else: self._idx.append(np.asarray(idx)) + def __iter__(self): yield from self._idx + def __len__(self): return len(self._idx) + def ravel(self): if len(self) == 0: return np.array([], dtype=np.int64) return concatenate([i for i in self]).ravel() + def tolist(self): return self._idx + def __str__(self): if len(self) == 0: return f"{self.__class__.__name__}{{empty}}" - out = ',\n '.join(map(lambda x: str(x.tolist()), self)) + out = ",\n ".join(map(lambda x: str(x.tolist()), self)) return f"{self.__class__.__name__}{{\n {out}}}" def _sort(val, atoms, **kwargs): - """ We do not sort according to lexsort """ + """We do not sort according to lexsort""" if len(val) <= 1: # no values to sort return atoms # control ascend vs descending - ascend = kwargs['ascend'] - atol = kwargs['atol'] + ascend = kwargs["ascend"] + atol = kwargs["atol"] new_atoms = NestedList() for atom in atoms: @@ -1535,13 +1573,15 @@ def _sort(val, atoms, **kwargs): # Functions allowed by external users funcs = dict() + def _axis(axis, atoms, **kwargs): - """ Cartesian coordinate sort """ + """Cartesian coordinate sort""" if isinstance(axis, int): axis = (axis,) for ax in axis: atoms = _sort(self.xyz[:, ax], atoms, **kwargs) return atoms + funcs["axis"] = _axis def _lattice(lattice, atoms, **kwargs): @@ -1555,6 +1595,7 @@ def _lattice(lattice, atoms, **kwargs): for ax in lattice: atoms = _sort(fxyz[:, ax] * self.lattice.length[ax], atoms, **kwargs) return atoms + funcs["lattice"] = _lattice def _vector(vector, atoms, **kwargs): @@ -1572,12 +1613,14 @@ def _vector(vector, atoms, **kwargs): vector /= fnorm(vector) # Perform a . b^ == scalar projection return _sort(self.xyz.dot(vector), atoms, **kwargs) + funcs["vector"] = _vector def _funcs(funcs, atoms, **kwargs): """ User defined function (tuple/list of function) """ + def _func(func, atoms, kwargs): nl = NestedList() for atom in atoms: @@ -1593,6 +1636,7 @@ def _func(func, atoms, kwargs): for func in funcs: atoms = _func(func, atoms, kwargs) return atoms + funcs["func"] = _funcs def _func_sort(funcs, atoms, **kwargs): @@ -1604,6 +1648,7 @@ def _func_sort(funcs, atoms, **kwargs): for func in funcs: atoms = _sort(func(self), atoms, **kwargs) return atoms + funcs["func_sort"] = _func_sort def _group_vals(vals, groups, atoms, **kwargs): @@ -1698,15 +1743,17 @@ def _group(method_group, atoms, **kwargs): vals = [getattr(a, method) for a in self.atoms] else: - raise ValueError(f"{self.__class__.__name__}.sort group only supports attributes that can be fetched from Atom objects, some are [Z, species, tag, symbol, mass, ...] and more") + raise ValueError( + f"{self.__class__.__name__}.sort group only supports attributes that can be fetched from Atom objects, some are [Z, species, tag, symbol, mass, ...] and more" + ) return _group_vals(np.array(vals), groups, atoms, **kwargs) funcs["group"] = _group def stripint(s): - """ Remove integers from end of string -> Allow multiple arguments """ - if s[-1] in '0123456789': + """Remove integers from end of string -> Allow multiple arguments""" + if s[-1] in "0123456789": return stripint(s[:-1]) return s @@ -1716,18 +1763,18 @@ def stripint(s): # We also allow specific keys for specific methods func_kw = dict() - func_kw['ascend'] = True - func_kw['atol'] = 1e-9 + func_kw["ascend"] = True + func_kw["atol"] = 1e-9 def update_flag(kw, arg, val): - if arg in ['ascending', 'ascend']: - kw['ascend'] = val + if arg in ["ascending", "ascend"]: + kw["ascend"] = val return True - elif arg in ['descending', 'descend']: - kw['ascend'] = not val + elif arg in ["descending", "descend"]: + kw["ascend"] = not val return True - elif arg == 'atol': - kw['atol'] = val + elif arg == "atol": + kw["atol"] = val return True return False @@ -1737,14 +1784,16 @@ def update_flag(kw, arg, val): # In case the user just did geometry.sort, it will default to sort x, y, z if len(kwargs) == 0: - kwargs['axis'] = (0, 1, 2) + kwargs["axis"] = (0, 1, 2) for key_int, method in kwargs.items(): key = stripint(key_int) if update_flag(func_kw, key, method): continue if not key in funcs: - raise ValueError(f"{self.__class__.__name__}.sort unrecognized keyword '{key}' ('{key_int}')") + raise ValueError( + f"{self.__class__.__name__}.sort unrecognized keyword '{key}' ('{key_int}')" + ) # call sorting algorithm and retrieve new grouped sorting atoms = funcs[key](method, atoms, **func_kw) @@ -1761,10 +1810,12 @@ def update_flag(kw, arg, val): return self.sub(atoms_flat), atoms.tolist() return self.sub(atoms_flat) - def optimize_nsc(self, - axes: Optional[Union[int, Sequence[int]]]=None, - R: Optional[float]=None) -> ndarray: - """ Optimize the number of supercell connections based on ``self.maxR()`` + def optimize_nsc( + self, + axes: Optional[Union[int, Sequence[int]]] = None, + R: Optional[float] = None, + ) -> ndarray: + """Optimize the number of supercell connections based on ``self.maxR()`` After this routine the number of supercells may not necessarily be the same. @@ -1788,9 +1839,11 @@ def optimize_nsc(self, R = self.maxR() + 0.001 if R < 0: R = 0.00001 - warn(f"{self.__class__.__name__}" - ".optimize_nsc could not determine the radius from the " - "internal atoms (defaulting to zero radius).") + warn( + f"{self.__class__.__name__}" + ".optimize_nsc could not determine the radius from the " + "internal atoms (defaulting to zero radius)." + ) ic = self.icell nrc = 1 / fnorm(ic) @@ -1837,7 +1890,7 @@ def optimize_nsc(self, return nsc def sub(self, atoms: AtomsArgument) -> Geometry: - """ Create a new `Geometry` with a subset of this `Geometry` + """Create a new `Geometry` with a subset of this `Geometry` Indices passed *MUST* be unique. @@ -1854,11 +1907,12 @@ def sub(self, atoms: AtomsArgument) -> Geometry: remove : the negative of this routine, i.e. remove a subset of atoms """ atoms = self.sc2uc(atoms) - return self.__class__(self.xyz[atoms, :], atoms=self.atoms.sub(atoms), lattice=self.lattice.copy()) + return self.__class__( + self.xyz[atoms, :], atoms=self.atoms.sub(atoms), lattice=self.lattice.copy() + ) - def sub_orbital(self, atoms: AtomsArgument, - orbitals: OrbitalsArgument) -> Geometry: - r""" Retain only a subset of the orbitals on `atoms` according to `orbitals` + def sub_orbital(self, atoms: AtomsArgument, orbitals: OrbitalsArgument) -> Geometry: + r"""Retain only a subset of the orbitals on `atoms` according to `orbitals` This allows one to retain only a given subset of geometry. @@ -1945,12 +1999,14 @@ def sub_orbital(self, atoms: AtomsArgument, orbitals = np.sort(orbitals) if len(orbitals) == 0: - raise ValueError(f"{self.__class__.__name__}.sub_orbital trying to retain 0 orbitals on a given atom. This is not allowed!") + raise ValueError( + f"{self.__class__.__name__}.sub_orbital trying to retain 0 orbitals on a given atom. This is not allowed!" + ) # create the new atom new_atom = old_atom.sub(orbitals) # Rename the new-atom to <>_1_2 for orbital == [1, 2] - new_atom._tag += '_' + '_'.join(map(str, orbitals)) + new_atom._tag += "_" + "_".join(map(str, orbitals)) # There are now 2 cases. # 1. we replace all atoms of a given specie @@ -1958,7 +2014,7 @@ def sub_orbital(self, atoms: AtomsArgument, if len(atoms) == old_atom_count: # We catch the warning about reducing the number of orbitals! with warnings.catch_warnings(): - warnings.filterwarnings('ignore') + warnings.filterwarnings("ignore") # this is in-place operation and we don't need to worry about geom.atoms.replace_atom(old_atom, new_atom) @@ -1977,7 +2033,7 @@ def sub_orbital(self, atoms: AtomsArgument, return geom def remove(self, atoms: AtomsArgument) -> Geometry: - """ Remove atoms from the geometry. + """Remove atoms from the geometry. Indices passed *MUST* be unique. @@ -1998,10 +2054,10 @@ def remove(self, atoms: AtomsArgument) -> Geometry: atoms = np.delete(_a.arangei(self.na), atoms) return self.sub(atoms) - def remove_orbital(self, - atoms: AtomsArgument, - orbitals: OrbitalsArgument) -> Geometry: - """ Remove a subset of orbitals on `atoms` according to `orbitals` + def remove_orbital( + self, atoms: AtomsArgument, orbitals: OrbitalsArgument + ) -> Geometry: + """Remove a subset of orbitals on `atoms` according to `orbitals` For more detailed examples, please see the equivalent (but opposite) method `sub_orbital`. @@ -2048,12 +2104,8 @@ def remove_orbital(self, # now call sub_orbital return self.sub_orbital(atoms, orbitals) - def unrepeat(self, - reps: int, - axis: int, - *args, - **kwargs) -> Geometry: - """ Unrepeats the geometry similarly as `untile` + def unrepeat(self, reps: int, axis: int, *args, **kwargs) -> Geometry: + """Unrepeats the geometry similarly as `untile` Please see `untile` for argument details, the algorithm and arguments are the same however, this is the opposite of `repeat`. @@ -2061,13 +2113,15 @@ def unrepeat(self, atoms = np.arange(self.na).reshape(-1, reps).T.ravel() return self.sub(atoms).untile(reps, axis, *args, **kwargs) - def untile(self, - reps: int, - axis: int, - segment: int=0, - rtol: float=1e-4, - atol: float=1e-4) -> Geometry: - """ A subset of atoms from the geometry by cutting the geometry into `reps` parts along the direction `axis`. + def untile( + self, + reps: int, + axis: int, + segment: int = 0, + rtol: float = 1e-4, + atol: float = 1e-4, + ) -> Geometry: + """A subset of atoms from the geometry by cutting the geometry into `reps` parts along the direction `axis`. This will effectively change the unit-cell in the `axis` as-well as removing ``self.na/reps`` atoms. @@ -2110,9 +2164,11 @@ def untile(self, tile : opposite method of this """ if self.na % reps != 0: - raise ValueError(f'{self.__class__.__name__}.untile ' - f'cannot be cut into {reps} different ' - 'pieces. Please check your geometry and input.') + raise ValueError( + f"{self.__class__.__name__}.untile " + f"cannot be cut into {reps} different " + "pieces. Please check your geometry and input." + ) # Truncate to the correct segments lseg = segment % reps # Cut down cell @@ -2123,14 +2179,14 @@ def untile(self, new = self.sub(_a.arangei(off, off + n)) new.set_lattice(lattice) if not np.allclose(new.tile(reps, axis).xyz, self.xyz, rtol=rtol, atol=atol): - warn("The cut structure cannot be re-created by tiling\n" - "The tolerance between the coordinates can be altered using rtol, atol") + warn( + "The cut structure cannot be re-created by tiling\n" + "The tolerance between the coordinates can be altered using rtol, atol" + ) return new - def tile(self, - reps: int, - axis: int) -> Geometry: - """ Tile the geometry to create a bigger one + def tile(self, reps: int, axis: int) -> Geometry: + """Tile the geometry to create a bigger one The atomic indices are retained for the base structure. @@ -2170,7 +2226,9 @@ def tile(self, untile : opposite method of this """ if reps < 1: - raise ValueError(f'{self.__class__.__name__}.tile requires a repetition above 0') + raise ValueError( + f"{self.__class__.__name__}.tile requires a repetition above 0" + ) lattice = self.lattice.tile(reps, axis) @@ -2189,10 +2247,8 @@ def tile(self, # will also expand via tiling) return self.__class__(xyz, atoms=self.atoms.tile(reps), lattice=lattice) - def repeat(self, - reps: int, - axis: int) -> Geometry: - """ Create a repeated geometry + def repeat(self, reps: int, axis: int) -> Geometry: + """Create a repeated geometry The atomic indices are *NOT* retained from the base structure. @@ -2243,7 +2299,9 @@ def repeat(self, tile : equivalent but different ordering of final structure """ if reps < 1: - raise ValueError(f'{self.__class__.__name__}.repeat requires a repetition above 0') + raise ValueError( + f"{self.__class__.__name__}.repeat requires a repetition above 0" + ) lattice = self.lattice.repeat(reps, axis) @@ -2262,8 +2320,8 @@ def repeat(self, # Create the geometry and return it return self.__class__(xyz, atoms=self.atoms.repeat(reps), lattice=lattice) - def __mul__(self, m, method='tile') -> Geometry: - """ Implement easy tile/repeat function + def __mul__(self, m, method="tile") -> Geometry: + """Implement easy tile/repeat function Parameters ---------- @@ -2307,12 +2365,7 @@ def __mul__(self, m, method='tile') -> Geometry: return self * m[0] # Look-up table - method_tbl = { - 'r': 'repeat', - 'repeat': 'repeat', - 't': 'tile', - 'tile': 'tile' - } + method_tbl = {"r": "repeat", "repeat": "repeat", "t": "tile", "tile": "tile"} # Determine the type if len(m) == 2: @@ -2345,15 +2398,13 @@ def __mul__(self, m, method='tile') -> Geometry: return g def __rmul__(self, m) -> Geometry: - """ Default to repeating the atomic structure """ - return self.__mul__(m, 'repeat') + """Default to repeating the atomic structure""" + return self.__mul__(m, "repeat") - def angle(self, - atoms: AtomsArgument, - dir=(1., 0, 0), - ref=None, - rad: bool=False) -> Union[float, ndarray]: - r""" The angle between atom `atoms` and the direction `dir`, with possibility of a reference coordinate `ref` + def angle( + self, atoms: AtomsArgument, dir=(1.0, 0, 0), ref=None, rad: bool = False + ) -> Union[float, ndarray]: + r"""The angle between atom `atoms` and the direction `dir`, with possibility of a reference coordinate `ref` The calculated angle can be written as this @@ -2379,7 +2430,7 @@ def angle(self, """ xi = self.axyz(atoms) if isinstance(dir, (str, Integral)): - dir = direction(dir, abc=self.cell, xyz=np.diag([1]*3)) + dir = direction(dir, abc=self.cell, xyz=np.diag([1] * 3)) else: dir = _a.asarrayd(dir) # Normalize so we don't have to have this in the @@ -2400,17 +2451,22 @@ def angle(self, return ang return np.degrees(ang) - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def rotate(self, - angle: float, - v, - origin=None, - atoms: Optional[AtomsArgument]=None, - rad: bool=False, - what: Optional[str]=None) -> Geometry: - r""" Rotate geometry around vector and return a new geometry + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def rotate( + self, + angle: float, + v, + origin=None, + atoms: Optional[AtomsArgument] = None, + rad: bool = False, + what: Optional[str] = None, + ) -> Geometry: + r"""Rotate geometry around vector and return a new geometry Per default will the entire geometry be rotated, such that everything is aligned as before rotation. @@ -2459,7 +2515,7 @@ def rotate(self, Lattice.rotate : rotation passed to the contained supercell """ if origin is None: - origin = [0., 0., 0.] + origin = [0.0, 0.0, 0.0] elif isinstance(origin, Integral): origin = self.axyz(origin) origin = _a.asarray(origin) @@ -2478,7 +2534,11 @@ def rotate(self, if isinstance(v, Integral): v = direction(v, abc=self.cell, xyz=np.diag([1, 1, 1])) elif isinstance(v, str): - v = reduce(lambda a, b: a + direction(b, abc=self.cell, xyz=np.diag([1, 1, 1])), v, 0) + v = reduce( + lambda a, b: a + direction(b, abc=self.cell, xyz=np.diag([1, 1, 1])), + v, + 0, + ) # Ensure the normal vector is normalized... (flatten == copy) vn = _a.asarrayd(v).flatten() @@ -2491,7 +2551,7 @@ def rotate(self, xyz = np.copy(self.xyz) idx = [] - for i, d in enumerate('xyz'): + for i, d in enumerate("xyz"): if d in what: idx.append(i) @@ -2500,7 +2560,7 @@ def rotate(self, q = Quaternion(angle, vn, rad=rad) q /= q.norm() # subtract and add origin, before and after rotation - rotated = (q.rotate(xyz[atoms] - origin) + origin) + rotated = q.rotate(xyz[atoms] - origin) + origin # get which coordinates to rotate for i in idx: xyz[atoms, i] = rotated[:, i] @@ -2508,16 +2568,20 @@ def rotate(self, return self.__class__(xyz, atoms=self.atoms.copy(), lattice=lattice) def rotate_miller(self, m, v) -> Geometry: - """ Align Miller direction along ``v`` + """Align Miller direction along ``v`` Rotate geometry and cell such that the Miller direction points along the Cartesian vector ``v``. """ # Create normal vector to miller direction and cartesian # direction - cp = _a.arrayd([m[1] * v[2] - m[2] * v[1], - m[2] * v[0] - m[0] * v[2], - m[0] * v[1] - m[1] * v[0]]) + cp = _a.arrayd( + [ + m[1] * v[2] - m[2] * v[1], + m[2] * v[0] - m[0] * v[2], + m[0] * v[1] - m[1] * v[0], + ] + ) cp /= fnorm(cp) lm = _a.arrayd(m) @@ -2529,11 +2593,10 @@ def rotate_miller(self, m, v) -> Geometry: a = acos(np.sum(lm * lv)) return self.rotate(a, cp, rad=True) - def translate(self, - v, - atoms: Optional[AtomsArgument]=None, - cell: bool=False) -> Geometry: - """ Translates the geometry by `v` + def translate( + self, v, atoms: Optional[AtomsArgument] = None, cell: bool = False + ) -> Geometry: + """Translates the geometry by `v` One can translate a subset of the atoms by supplying `atoms`. @@ -2558,11 +2621,12 @@ def translate(self, if cell: g.set_lattice(g.lattice.translate(v)) return g + move = translate - def translate2uc(self, - atoms: Optional[AtomsArgument]=None, - axes=None) -> Geometry: + def translate2uc( + self, atoms: Optional[AtomsArgument] = None, axes=None + ) -> Geometry: """Translates atoms in the geometry into the unit cell One can translate a subset of the atoms or axes by appropriate arguments. @@ -2598,7 +2662,9 @@ def translate2uc(self, if axes: axes = (0, 1, 2) else: - raise ValueError("translate2uc with a bool argument can only be True to signal all axes") + raise ValueError( + "translate2uc with a bool argument can only be True to signal all axes" + ) fxyz = self.fxyz # move to unit-cell @@ -2612,10 +2678,8 @@ def translate2uc(self, g.xyz[idx] = fxyz[idx] @ self.cell return g - def swap(self, - atoms_a: AtomsArgument, - atoms_b: AtomsArgument) -> Geometry: - """ Swap a set of atoms in the geometry and return a new one + def swap(self, atoms_a: AtomsArgument, atoms_b: AtomsArgument) -> Geometry: + """Swap a set of atoms in the geometry and return a new one This can be used to reorder elements of a geometry. @@ -2631,13 +2695,14 @@ def swap(self, xyz = np.copy(self.xyz) xyz[atoms_a, :] = self.xyz[atoms_b, :] xyz[atoms_b, :] = self.xyz[atoms_a, :] - return self.__class__(xyz, atoms=self.atoms.swap(atoms_a, atoms_b), lattice=self.lattice.copy()) + return self.__class__( + xyz, atoms=self.atoms.swap(atoms_a, atoms_b), lattice=self.lattice.copy() + ) - def swapaxes(self, - axes_a: Union[int, str], - axes_b: Union[int, str], - what: str="abc") -> Geometry: - """ Swap the axes components by either lattice vectors (only cell), or Cartesian coordinates + def swapaxes( + self, axes_a: Union[int, str], axes_b: Union[int, str], what: str = "abc" + ) -> Geometry: + """Swap the axes components by either lattice vectors (only cell), or Cartesian coordinates See `Lattice.swapaxes` for details. @@ -2706,12 +2771,14 @@ def swapaxes(self, if aidx < 3: idx[aidx], idx[bidx] = idx[bidx], idx[aidx] - return self.__class__(self.xyz[:, idx].copy(), atoms=self.atoms.copy(), lattice=lattice) + return self.__class__( + self.xyz[:, idx].copy(), atoms=self.atoms.copy(), lattice=lattice + ) - def center(self, - atoms: Optional[AtomsArgument]=None, - what: str="xyz") -> ndarray: - """ Returns the center of the geometry + def center( + self, atoms: Optional[AtomsArgument] = None, what: str = "xyz" + ) -> ndarray: + """Returns the center of the geometry By specifying `what` one can control whether it should be: @@ -2748,7 +2815,7 @@ def center(self, # construct angles avg_cos = (mass @ np.cos(theta)) / sum_mass avg_sin = (mass @ np.sin(theta)) / sum_mass - avg_theta = np.arctan2(-avg_sin, -avg_cos) / (2*np.pi) + 0.5 + avg_theta = np.arctan2(-avg_sin, -avg_cos) / (2 * np.pi) + 0.5 return avg_theta @ g.lattice.cell if "mass" == what: @@ -2761,13 +2828,14 @@ def center(self, if what in ("xyz", "position"): return np.mean(g.xyz, axis=0) - raise ValueError(f"{self.__class__.__name__}.center could not understand option 'what' got {what}") + raise ValueError( + f"{self.__class__.__name__}.center could not understand option 'what' got {what}" + ) - def append(self, - other: LatticeOrGeometryLike, - axis: int, - offset="none") -> Geometry: - """ Appends two structures along `axis` + def append( + self, other: LatticeOrGeometryLike, axis: int, offset="none" + ) -> Geometry: + """Appends two structures along `axis` This will automatically add the ``self.cell[axis,:]`` to all atomic coordiates in the `other` structure before appending. @@ -2816,7 +2884,9 @@ def append(self, if offset == "none": offset = [0, 0, 0] else: - raise ValueError(f"{self.__class__.__name__}.append requires offset to be (3,) for supercell input") + raise ValueError( + f"{self.__class__.__name__}.append requires offset to be (3,) for supercell input" + ) xyz += _a.asarray(offset) else: @@ -2824,15 +2894,17 @@ def append(self, other = self.new(other) if isinstance(offset, str): offset = offset.lower() - if offset == 'none': + if offset == "none": offset = self.cell[axis, :] - elif offset == 'min': + elif offset == "min": # We want to align at the minimum position along the `axis` min_f = self.fxyz[:, axis].min() min_other_f = dot(other.xyz, self.icell.T)[:, axis].min() offset = self.cell[axis, :] * (1 + min_f - min_other_f) else: - raise ValueError(f'{self.__class__.__name__}.append requires align keyword to be one of [none, min, (3,)]') + raise ValueError( + f"{self.__class__.__name__}.append requires align keyword to be one of [none, min, (3,)]" + ) else: offset = self.cell[axis, :] + _a.asarray(offset) @@ -2843,11 +2915,10 @@ def append(self, return self.__class__(xyz, atoms=atoms, lattice=lattice, names=names) - def prepend(self, - other: LatticeOrGeometryLike, - axis: int, - offset="none") -> Geometry: - """ Prepend two structures along `axis` + def prepend( + self, other: LatticeOrGeometryLike, axis: int, offset="none" + ) -> Geometry: + """Prepend two structures along `axis` This will automatically add the ``self.cell[axis,:]`` to all atomic coordiates in the `other` structure before appending. @@ -2896,7 +2967,9 @@ def prepend(self, if offset == "none": offset = [0, 0, 0] else: - raise ValueError(f"{self.__class__.__name__}.prepend requires offset to be (3,) for supercell input") + raise ValueError( + f"{self.__class__.__name__}.prepend requires offset to be (3,) for supercell input" + ) xyz += _a.arrayd(offset) else: @@ -2904,15 +2977,17 @@ def prepend(self, other = self.new(other) if isinstance(offset, str): offset = offset.lower() - if offset == 'none': + if offset == "none": offset = other.cell[axis, :] - elif offset == 'min': + elif offset == "min": # We want to align at the minimum position along the `axis` min_f = other.fxyz[:, axis].min() min_other_f = dot(self.xyz, other.icell.T)[:, axis].min() offset = other.cell[axis, :] * (1 + min_f - min_other_f) else: - raise ValueError(f'{self.__class__.__name__}.prepend requires align keyword to be one of [none, min, (3,)]') + raise ValueError( + f"{self.__class__.__name__}.prepend requires align keyword to be one of [none, min, (3,)]" + ) else: offset = other.cell[axis, :] + _a.asarray(offset) @@ -2923,10 +2998,8 @@ def prepend(self, return self.__class__(xyz, atoms=atoms, lattice=lattice, names=names) - def add(self, - other: LatticeOrGeometryLike, - offset=(0, 0, 0)) -> Geometry: - """ Merge two geometries (or a Geometry and Lattice) by adding the two atoms together + def add(self, other: LatticeOrGeometryLike, offset=(0, 0, 0)) -> Geometry: + """Merge two geometries (or a Geometry and Lattice) by adding the two atoms together If `other` is a Geometry only the atoms gets added, to also add the supercell vectors simply do ``geom.add(other).add(other.lattice)``. @@ -2959,10 +3032,8 @@ def add(self, names = self._names.merge(other._names, offset=len(self)) return self.__class__(xyz, atoms=atoms, lattice=lattice, names=names) - def add_vacuum(self, - vacuum: float, - axis: int) -> Geometry: - """ Add vacuum along the `axis` lattice vector + def add_vacuum(self, vacuum: float, axis: int) -> Geometry: + """Add vacuum along the `axis` lattice vector When the vacuum is bigger than the maximum orbital ranges the number of supercells along that axis will be truncated to 1 (de-couple @@ -2988,10 +3059,8 @@ def add_vacuum(self, new.lattice.set_nsc(nsc) return new - def insert(self, - atom: int, - other: GeometryLike) -> Geometry: - """ Inserts other atoms right before index + def insert(self, atom: int, other: GeometryLike) -> Geometry: + """Inserts other atoms right before index We insert the `geometry` `Geometry` before `atom`. Note that this will not change the unit cell. @@ -3012,14 +3081,16 @@ def insert(self, """ atom = self._sanitize_atoms(atom) if atom.size > 1: - raise ValueError(f"{self.__class__.__name__}.insert requires only 1 atomic index for insertion.") + raise ValueError( + f"{self.__class__.__name__}.insert requires only 1 atomic index for insertion." + ) other = self.new(other) xyz = np.insert(self.xyz, atom, other.xyz, axis=0) atoms = self.atoms.insert(atom, other.atoms) return self.__class__(xyz, atoms, lattice=self.lattice.copy()) def __add__(self, b) -> Geometry: - """ Merge two geometries (or geometry and supercell) + """Merge two geometries (or geometry and supercell) Parameters ---------- @@ -3049,7 +3120,7 @@ def __add__(self, b) -> Geometry: return self.append(b[0], b[1]) def __radd__(self, b) -> Geometry: - """ Merge two geometries (or geometry and supercell) + """Merge two geometries (or geometry and supercell) Parameters ---------- @@ -3078,13 +3149,15 @@ def __radd__(self, b) -> Geometry: return b.add(self) return self + b - def attach(self, - atom: int, - other: GeometryLike, - other_atom: int, - dist='calc', - axis: Optional[int]=None) -> Geometry: - """ Attaches another `Geometry` at the `atom` index with respect to `other_atom` using different methods. + def attach( + self, + atom: int, + other: GeometryLike, + other_atom: int, + dist="calc", + axis: Optional[int] = None, + ) -> Geometry: + """Attaches another `Geometry` at the `atom` index with respect to `other_atom` using different methods. The attached geometry will be inserted at the end of the geometry via `add`. @@ -3115,17 +3188,21 @@ def attach(self, if isinstance(dist, Real): # We have a single rational number if axis is None: - raise ValueError(f"{self.__class__.__name__}.attach, `axis` has not been specified, please specify the axis when using a distance") + raise ValueError( + f"{self.__class__.__name__}.attach, `axis` has not been specified, please specify the axis when using a distance" + ) # Now calculate the vector that we should have # between the atoms v = self.cell[axis, :] - v = v / (v ** 2).sum() ** 0.5 * dist + v = v / (v**2).sum() ** 0.5 * dist elif isinstance(dist, str): # We have a single rational number if axis is None: - raise ValueError(f"{self.__class__.__name__}.attach, `axis` has not been specified, please specify the axis when using a distance") + raise ValueError( + f"{self.__class__.__name__}.attach, `axis` has not been specified, please specify the axis when using a distance" + ) # This is the empirical distance between the atoms d = self.atoms[atom].radius(dist) + other.atoms[other_atom].radius(dist) @@ -3134,7 +3211,7 @@ def attach(self, else: v = np.array(axis) - v = v / (v ** 2).sum() ** 0.5 * d + v = v / (v**2).sum() ** 0.5 * d else: # The user *must* have supplied a vector @@ -3149,11 +3226,10 @@ def attach(self, # so we will do nothing... return self.add(o) - def replace(self, - atoms: AtomsArgument, - other: GeometryLike, - offset=None) -> Geometry: - """ Create a new geometry from `self` and replace `atoms` with `other` + def replace( + self, atoms: AtomsArgument, other: GeometryLike, offset=None + ) -> Geometry: + """Create a new geometry from `self` and replace `atoms` with `other` Parameters ---------- @@ -3182,8 +3258,8 @@ def replace(self, out._atoms = out.atoms.insert(index, other.atoms) return out - def reverse(self, atoms: Optional[AtomsArgument]=None) -> Geometry: - """ Returns a reversed geometry + def reverse(self, atoms: Optional[AtomsArgument] = None) -> Geometry: + """Returns a reversed geometry Also enables reversing a subset of the atoms. @@ -3199,10 +3275,14 @@ def reverse(self, atoms: Optional[AtomsArgument]=None) -> Geometry: atoms = self._sanitize_atoms(atoms).ravel() xyz = np.copy(self.xyz) xyz[atoms, :] = self.xyz[atoms[::-1], :] - return self.__class__(xyz, atoms=self.atoms.reverse(atoms), lattice=self.lattice.copy()) + return self.__class__( + xyz, atoms=self.atoms.reverse(atoms), lattice=self.lattice.copy() + ) - def mirror(self, method, atoms: Optional[AtomsArgument]=None, point=(0, 0, 0)) -> Geometry: - r""" Mirrors the atomic coordinates about a plane given by its normal vector + def mirror( + self, method, atoms: Optional[AtomsArgument] = None, point=(0, 0, 0) + ) -> Geometry: + r"""Mirrors the atomic coordinates about a plane given by its normal vector This will typically move the atomic coordinates outside of the unit-cell. This method should be used with care. @@ -3235,27 +3315,29 @@ def mirror(self, method, atoms: Optional[AtomsArgument]=None, point=(0, 0, 0)) - point = _a.asarrayd(point) if isinstance(method, str): - method = ''.join(sorted(method.lower())) - if method in ('z', 'xy'): + method = "".join(sorted(method.lower())) + if method in ("z", "xy"): method = _a.arrayd([0, 0, 1]) - elif method in ('x', 'yz'): + elif method in ("x", "yz"): method = _a.arrayd([1, 0, 0]) - elif method in ('y', 'xz'): + elif method in ("y", "xz"): method = _a.arrayd([0, 1, 0]) - elif method == 'a': + elif method == "a": method = self.cell[0] - elif method == 'b': + elif method == "b": method = self.cell[1] - elif method == 'c': + elif method == "c": method = self.cell[2] - elif method == 'ab': + elif method == "ab": method = cross3(self.cell[0], self.cell[1]) - elif method == 'ac': + elif method == "ac": method = cross3(self.cell[0], self.cell[2]) - elif method == 'bc': + elif method == "bc": method = cross3(self.cell[1], self.cell[2]) else: - raise ValueError(f"{self.__class__.__name__}.mirror unrecognized 'method' value") + raise ValueError( + f"{self.__class__.__name__}.mirror unrecognized 'method' value" + ) # it has to be an array of length 3 # Mirror about a user defined vector @@ -3271,7 +3353,7 @@ def mirror(self, method, atoms: Optional[AtomsArgument]=None, point=(0, 0, 0)) - g.xyz[atoms, :] -= vp.reshape(-1, 1) * method.reshape(1, 3) return g - def axyz(self, atoms: Optional[AtomsArgument]=None, isc=None) -> ndarray: + def axyz(self, atoms: Optional[AtomsArgument] = None, isc=None) -> ndarray: """Return the atomic coordinates in the supercell of a given atom. The ``Geometry[...]`` slicing is calling this function with appropriate options. @@ -3311,11 +3393,8 @@ def axyz(self, atoms: Optional[AtomsArgument]=None, isc=None) -> ndarray: # Neither of atoms, or isc are `None`, we add the offset to all coordinates return self.axyz(atoms) + self.lattice.offset(isc) - def scale(self, - scale, - what:str ="abc", - scale_atoms: bool=True) -> Geometry: - """ Scale coordinates and unit-cell to get a new geometry with proper scaling + def scale(self, scale, what: str = "abc", scale_atoms: bool = True) -> Geometry: + """Scale coordinates and unit-cell to get a new geometry with proper scaling Parameters ---------- @@ -3357,7 +3436,9 @@ def scale(self, scaled_span = scaled_verts.max(axis=0) - scaled_verts.min(axis=0) max_scale = (scaled_span / prev_span).max() else: - raise ValueError(f"{self.__class__.__name__}.scale got wrong what argument, must be one of abc|xyz") + raise ValueError( + f"{self.__class__.__name__}.scale got wrong what argument, must be one of abc|xyz" + ) if scale_atoms: # Atoms are rescaled to the maximum scale factor @@ -3367,14 +3448,16 @@ def scale(self, return self.__class__(xyz, atoms=atoms, lattice=lattice) - def within_sc(self, - shapes, - isc=None, - atoms: Optional[AtomsArgument]=None, - atoms_xyz=None, - ret_xyz: bool=False, - ret_rij: bool=False): - """ Indices of atoms in a given supercell within a given shape from a given coordinate + def within_sc( + self, + shapes, + isc=None, + atoms: Optional[AtomsArgument] = None, + atoms_xyz=None, + ret_xyz: bool = False, + ret_rij: bool = False, + ): + """Indices of atoms in a given supercell within a given shape from a given coordinate This returns a set of atomic indices which are within a sphere of radius ``R``. @@ -3529,15 +3612,17 @@ def within_sc(self, return ret return ret[0] - def close_sc(self, - xyz_ia, - isc=(0, 0, 0), - R=None, - atoms: Optional[AtomsArgument]=None, - atoms_xyz=None, - ret_xyz=False, - ret_rij=False): - """ Indices of atoms in a given supercell within a given radius from a given coordinate + def close_sc( + self, + xyz_ia, + isc=(0, 0, 0), + R=None, + atoms: Optional[AtomsArgument] = None, + atoms_xyz=None, + ret_xyz=False, + ret_rij=False, + ): + """Indices of atoms in a given supercell within a given radius from a given coordinate This returns a set of atomic indices which are within a sphere of radius `R`. @@ -3591,10 +3676,12 @@ def close_sc(self, # Maximum distance queried max_R = R[-1] if atoms is not None and max_R > maxR + 0.1: - warn(f"{self.__class__.__name__}.close_sc has been passed an 'atoms' argument " - "together with an R value larger than the orbital ranges. " - "If used together with 'sparse-matrix.construct' this can result in wrong couplings.", - register=True) + warn( + f"{self.__class__.__name__}.close_sc has been passed an 'atoms' argument " + "together with an R value larger than the orbital ranges. " + "If used together with 'sparse-matrix.construct' this can result in wrong couplings.", + register=True, + ) # Convert to actual array if atoms is not None: @@ -3677,8 +3764,10 @@ def close_sc(self, return ret[0] if not is_ascending(R): - raise ValueError(f"{self.__class__.__name__}.close_sc proximity checks for several " - "quantities at a time requires ascending R values.") + raise ValueError( + f"{self.__class__.__name__}.close_sc proximity checks for several " + "quantities at a time requires ascending R values." + ) # The more neigbours you wish to find the faster this becomes # We only do "one" heavy duty search, @@ -3699,33 +3788,30 @@ def close_sc(self, # Notice that this sub-space reduction will never # allow the same indice to be in two ranges (due to # numerics) - tidx = indices_gt_le(d, R[i-1], R[i]) + tidx = indices_gt_le(d, R[i - 1], R[i]) r_app(atoms[tidx]) r_appx(xa[tidx]) r_appd(d[tidx]) elif ret_xyz: for i in range(1, len(R)): - tidx = indices_gt_le(d, R[i-1], R[i]) + tidx = indices_gt_le(d, R[i - 1], R[i]) r_app(atoms[tidx]) r_appx(xa[tidx]) elif ret_rij: for i in range(1, len(R)): - tidx = indices_gt_le(d, R[i-1], R[i]) + tidx = indices_gt_le(d, R[i - 1], R[i]) r_app(atoms[tidx]) r_appd(d[tidx]) else: for i in range(1, len(R)): - tidx = indices_gt_le(d, R[i-1], R[i]) + tidx = indices_gt_le(d, R[i - 1], R[i]) r_app(atoms[tidx]) if ret_xyz or ret_rij: return ret return ret[0] - def bond_correct(self, - ia: int, - atoms: AtomsArgument, - method='calc') -> None: + def bond_correct(self, ia: int, atoms: AtomsArgument, method="calc") -> None: """Corrects the bond between `ia` and the `atoms`. Corrects the bond-length between atom `ia` and `atoms` in such @@ -3754,12 +3840,12 @@ def bond_correct(self, algo = -1 if algo >= 0: - # We have a single atom # Get bond length in the closest direction # A bond-length HAS to be below 10 - atoms, c, d = self.close(ia, R=(0.1, 10.), atoms=algo, - ret_xyz=True, ret_rij=True) + atoms, c, d = self.close( + ia, R=(0.1, 10.0), atoms=algo, ret_xyz=True, ret_rij=True + ) i = np.argmin(d[1]) # Convert to unitcell atom (and get the one atom) atoms = self.sc2uc(atoms[1][i]) @@ -3774,24 +3860,26 @@ def bond_correct(self, rad = float(method) except Exception: # get radius - rad = self.atoms[atoms].radius(method) \ - + self.atoms[ia].radius(method) + rad = self.atoms[atoms].radius(method) + self.atoms[ia].radius(method) # Update the coordinate self.xyz[ia, :] = c + bv / d * rad else: raise NotImplementedError( - 'Changing bond-length dependent on several lacks implementation.') - - def within(self, - shapes, - atoms: Optional[AtomsArgument]=None, - atoms_xyz=None, - ret_xyz: bool=False, - ret_rij: bool=False, - ret_isc: bool=False): - """ Indices of atoms in the entire supercell within a given shape from a given coordinate + "Changing bond-length dependent on several lacks implementation." + ) + + def within( + self, + shapes, + atoms: Optional[AtomsArgument] = None, + atoms_xyz=None, + ret_xyz: bool = False, + ret_rij: bool = False, + ret_isc: bool = False, + ): + """Indices of atoms in the entire supercell within a given shape from a given coordinate This heavily relies on the `within_sc` method. @@ -3856,12 +3944,16 @@ def isc_tile(isc, n): return tile(isc.reshape(1, -1), (n, 1)) for s in range(self.n_s): - na = self.na * s isc = self.lattice.sc_off[s, :] - sret = self.within_sc(shapes, self.lattice.sc_off[s, :], - atoms=atoms, atoms_xyz=atoms_xyz, - ret_xyz=ret_xyz, ret_rij=ret_rij) + sret = self.within_sc( + shapes, + self.lattice.sc_off[s, :], + atoms=atoms, + atoms_xyz=atoms_xyz, + ret_xyz=ret_xyz, + ret_rij=ret_rij, + ) if listify: # This is to "fake" the return @@ -3873,11 +3965,17 @@ def isc_tile(isc, n): for i, x in enumerate(sret[0]): ret[0][i] = concatenate((ret[0][i], x + na), axis=0) if ret_xyz: - ret[ixyz][i] = concatenate((ret[ixyz][i], sret[ixyz][i]), axis=0) + ret[ixyz][i] = concatenate( + (ret[ixyz][i], sret[ixyz][i]), axis=0 + ) if ret_rij: - ret[irij][i] = concatenate((ret[irij][i], sret[irij][i]), axis=0) + ret[irij][i] = concatenate( + (ret[irij][i], sret[irij][i]), axis=0 + ) if ret_isc: - ret[iisc][i] = concatenate((ret[iisc][i], isc_tile(isc, len(x))), axis=0) + ret[iisc][i] = concatenate( + (ret[iisc][i], isc_tile(isc, len(x))), axis=0 + ) elif len(sret[0]) > 0: # We can add it to the list (nshapes == 1) # We add the atomic offset for the supercell index @@ -3887,7 +3985,9 @@ def isc_tile(isc, n): if ret_rij: ret[irij][0] = concatenate((ret[irij][0], sret[irij]), axis=0) if ret_isc: - ret[iisc][0] = concatenate((ret[iisc][0], isc_tile(isc, len(sret[0]))), axis=0) + ret[iisc][0] = concatenate( + (ret[iisc][0], isc_tile(isc, len(sret[0]))), axis=0 + ) if nshapes == 1: if n_ret == 0: @@ -3898,15 +3998,17 @@ def isc_tile(isc, n): return ret[0] return ret - def close(self, - xyz_ia, - R=None, - atoms: Optional[AtomsArgument]=None, - atoms_xyz=None, - ret_xyz: bool=False, - ret_rij: bool=False, - ret_isc: bool=False): - """ Indices of atoms in the entire supercell within a given radius from a given coordinate + def close( + self, + xyz_ia, + R=None, + atoms: Optional[AtomsArgument] = None, + atoms_xyz=None, + ret_xyz: bool = False, + ret_rij: bool = False, + ret_isc: bool = False, + ): + """Indices of atoms in the entire supercell within a given radius from a given coordinate This heavily relies on the `close_sc` method. @@ -3991,12 +4093,17 @@ def isc_tile(isc, n): return tile(isc.reshape(1, -1), (n, 1)) for s in range(self.n_s): - na = self.na * s isc = self.lattice.sc_off[s] - sret = self.close_sc(xyz_ia, isc, R=R, - atoms=atoms, atoms_xyz=atoms_xyz, - ret_xyz=ret_xyz, ret_rij=ret_rij) + sret = self.close_sc( + xyz_ia, + isc, + R=R, + atoms=atoms, + atoms_xyz=atoms_xyz, + ret_xyz=ret_xyz, + ret_rij=ret_rij, + ) if listify: # This is to "fake" the return @@ -4008,11 +4115,17 @@ def isc_tile(isc, n): for i, x in enumerate(sret[0]): ret[0][i] = concatenate((ret[0][i], x + na), axis=0) if ret_xyz: - ret[ixyz][i] = concatenate((ret[ixyz][i], sret[ixyz][i]), axis=0) + ret[ixyz][i] = concatenate( + (ret[ixyz][i], sret[ixyz][i]), axis=0 + ) if ret_rij: - ret[irij][i] = concatenate((ret[irij][i], sret[irij][i]), axis=0) + ret[irij][i] = concatenate( + (ret[irij][i], sret[irij][i]), axis=0 + ) if ret_isc: - ret[iisc][i] = concatenate((ret[iisc][i], isc_tile(isc, len(x))), axis=0) + ret[iisc][i] = concatenate( + (ret[iisc][i], isc_tile(isc, len(x))), axis=0 + ) elif len(sret[0]) > 0: # We can add it to the list (len(R) == 1) # We add the atomic offset for the supercell index @@ -4022,7 +4135,9 @@ def isc_tile(isc, n): if ret_rij: ret[irij][0] = concatenate((ret[irij][0], sret[irij]), axis=0) if ret_isc: - ret[iisc][0] = concatenate((ret[iisc][0], isc_tile(isc, len(sret[0]))), axis=0) + ret[iisc][0] = concatenate( + (ret[iisc][0], isc_tile(isc, len(sret[0]))), axis=0 + ) if nR == 1: if n_ret == 0: @@ -4033,9 +4148,9 @@ def isc_tile(isc, n): return ret[0] return ret - def a2transpose(self, - atoms1: AtomsArgument, - atoms2: Optional[AtomsArgument]=None) -> Tuple[ndarray, ndarray]: + def a2transpose( + self, atoms1: AtomsArgument, atoms2: Optional[AtomsArgument] = None + ) -> Tuple[ndarray, ndarray]: """Transposes connections from `atoms1` to `atoms2` such that supercell connections are transposed When handling supercell indices it is useful to get the *transposed* connection. I.e. if you have @@ -4080,12 +4195,14 @@ def a2transpose(self, atoms2 = self._sanitize_atoms(atoms2) if atoms1.size == atoms2.size: pass - elif atoms1.size == 1: # typical case where atoms1 is a single number + elif atoms1.size == 1: # typical case where atoms1 is a single number atoms1 = np.tile(atoms1, atoms2.size) elif atoms2.size == 1: atoms2 = np.tile(atoms2, atoms1.size) else: - raise ValueError(f"{self.__class__.__name__}.a2transpose only allows length 1 or same length arrays.") + raise ValueError( + f"{self.__class__.__name__}.a2transpose only allows length 1 or same length arrays." + ) # Now convert atoms na = self.na @@ -4097,10 +4214,10 @@ def a2transpose(self, atoms2 = atoms2 % na + sc_index(-isc1) * na return atoms2, atoms1 - def o2transpose(self, - orb1: OrbitalsArgument, - orb2: Optional[OrbitalsArgument]=None) -> Tuple[ndarray, ndarray]: - """ Transposes connections from `orb1` to `orb2` such that supercell connections are transposed + def o2transpose( + self, orb1: OrbitalsArgument, orb2: Optional[OrbitalsArgument] = None + ) -> Tuple[ndarray, ndarray]: + """Transposes connections from `orb1` to `orb2` such that supercell connections are transposed When handling supercell indices it is useful to get the *transposed* connection. I.e. if you have a connection from site ``i`` (in unit cell indices) to site ``J`` (in supercell indices) it may be @@ -4144,12 +4261,14 @@ def o2transpose(self, orb2 = self._sanitize_orbs(orb2) if orb1.size == orb2.size: pass - elif orb1.size == 1: # typical case where orb1 is a single number + elif orb1.size == 1: # typical case where orb1 is a single number orb1 = np.tile(orb1, orb2.size) elif orb2.size == 1: orb2 = np.tile(orb2, orb1.size) else: - raise ValueError(f"{self.__class__.__name__}.o2transpose only allows length 1 or same length arrays.") + raise ValueError( + f"{self.__class__.__name__}.o2transpose only allows length 1 or same length arrays." + ) # Now convert orbs no = self.no @@ -4161,9 +4280,7 @@ def o2transpose(self, orb2 = orb2 % no + sc_index(-isc1) * no return orb2, orb1 - def a2o(self, - atoms: AtomsArgument, - all: bool=False) -> ndarray: + def a2o(self, atoms: AtomsArgument, all: bool = False) -> ndarray: """ Returns an orbital index of the first orbital of said atom. This is particularly handy if you want to create @@ -4194,10 +4311,8 @@ def a2o(self, return _a.array_arange(ob, oe) - def o2a(self, - orbitals: OrbitalsArgument, - unique: bool=False) -> ndarray: - """ Atomic index corresponding to the orbital indicies. + def o2a(self, orbitals: OrbitalsArgument, unique: bool = False) -> ndarray: + """Atomic index corresponding to the orbital indicies. Note that this will preserve the super-cell offsets. @@ -4211,7 +4326,10 @@ def o2a(self, orbitals = self._sanitize_orbs(orbitals) if orbitals.ndim == 0: # must only be 1 number (an Integral) - return np.argmax(orbitals % self.no <= self.lasto) + (orbitals // self.no) * self.na + return ( + np.argmax(orbitals % self.no <= self.lasto) + + (orbitals // self.no) * self.na + ) isc, orbitals = np.divmod(_a.asarrayi(orbitals.ravel()), self.no) a = list_index_le(orbitals, self.lasto) @@ -4219,10 +4337,8 @@ def o2a(self, return np.unique(a + isc * self.na) return a + isc * self.na - def uc2sc(self, - atoms: AtomsArgument, - unique: bool=False) -> ndarray: - """ Returns atom from unit-cell indices to supercell indices, possibly removing dublicates + def uc2sc(self, atoms: AtomsArgument, unique: bool = False) -> ndarray: + """Returns atom from unit-cell indices to supercell indices, possibly removing dublicates Parameters ---------- @@ -4232,16 +4348,17 @@ def uc2sc(self, If True the returned indices are unique and sorted. """ atoms = self._sanitize_atoms(atoms) % self.na - atoms = (atoms.reshape(1, -1) + _a.arangei(self.n_s).reshape(-1, 1) * self.na).ravel() + atoms = ( + atoms.reshape(1, -1) + _a.arangei(self.n_s).reshape(-1, 1) * self.na + ).ravel() if unique: return np.unique(atoms) return atoms + auc2sc = uc2sc - def sc2uc(self, - atoms: AtomsArgument, - unique: bool=False) -> ndarray: - """ Returns atoms from supercell indices to unit-cell indices, possibly removing dublicates + def sc2uc(self, atoms: AtomsArgument, unique: bool = False) -> ndarray: + """Returns atoms from supercell indices to unit-cell indices, possibly removing dublicates Parameters ---------- @@ -4254,12 +4371,11 @@ def sc2uc(self, if unique: return np.unique(atoms) return atoms + asc2uc = sc2uc - def osc2uc(self, - orbitals: OrbitalsArgument, - unique: bool=False) -> ndarray: - """ Orbitals from supercell indices to unit-cell indices, possibly removing dublicates + def osc2uc(self, orbitals: OrbitalsArgument, unique: bool = False) -> ndarray: + """Orbitals from supercell indices to unit-cell indices, possibly removing dublicates Parameters ---------- @@ -4273,10 +4389,8 @@ def osc2uc(self, return np.unique(orbitals) return orbitals - def ouc2sc(self, - orbitals: OrbitalsArgument, - unique: bool=False) -> ndarray: - """ Orbitals from unit-cell indices to supercell indices, possibly removing dublicates + def ouc2sc(self, orbitals: OrbitalsArgument, unique: bool = False) -> ndarray: + """Orbitals from unit-cell indices to supercell indices, possibly removing dublicates Parameters ---------- @@ -4286,15 +4400,16 @@ def ouc2sc(self, If True the returned indices are unique and sorted. """ orbitals = self._sanitize_orbs(orbitals) % self.no - orbitals = (orbitals.reshape(1, *orbitals.shape) + - _a.arangei(self.n_s) - .reshape(-1, *([1] * orbitals.ndim)) * self.no).ravel() + orbitals = ( + orbitals.reshape(1, *orbitals.shape) + + _a.arangei(self.n_s).reshape(-1, *([1] * orbitals.ndim)) * self.no + ).ravel() if unique: return np.unique(orbitals) return orbitals def a2isc(self, atoms: AtomsArgument) -> ndarray: - """ Super-cell indices for a specific/list atom + """Super-cell indices for a specific/list atom Returns a vector of 3 numbers with integers. Any multi-dimensional input will be flattened before return. @@ -4341,13 +4456,16 @@ def o2sc(self, orbitals: OrbitalsArgument) -> ndarray: """ return self.lattice.offset(self.o2isc(orbitals)) - def __plot__(self, - axis=None, - lattice: bool=True, - axes=False, - atom_indices: bool=False, - *args, **kwargs): - """ Plot the geometry in a specified ``matplotlib.Axes`` object. + def __plot__( + self, + axis=None, + lattice: bool = True, + axes=False, + atom_indices: bool = False, + *args, + **kwargs, + ): + """Plot the geometry in a specified ``matplotlib.Axes`` object. Parameters ---------- @@ -4367,8 +4485,8 @@ def __plot__(self, colors = np.linspace(0, 1, num=self.atoms.nspecie, endpoint=False) colors = colors[self.atoms.specie] - if 's' in kwargs: - area = kwargs.pop('s') + if "s" in kwargs: + area = kwargs.pop("s") else: area = _a.arrayd(self.atoms.Z) area[:] *= 20 * np.pi / area.min() @@ -4378,7 +4496,7 @@ def __plot__(self, # Ensure we have a new 3D Axes3D if len(axis) == 3: - d['projection'] = '3d' + d["projection"] = "3d" # The Geometry determines the axes, then we pass it to supercell. axes = plt.get_axes(axes, **d) @@ -4390,30 +4508,31 @@ def __plot__(self, # Create short-hand xyz = self.xyz - if axes.__class__.__name__.startswith('Axes3D'): + if axes.__class__.__name__.startswith("Axes3D"): # We should plot in 3D plots axes.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], s=area, c=colors, alpha=0.8) - axes.set_zlabel('Ang') + axes.set_zlabel("Ang") if atom_indices: for i, loc in enumerate(xyz): - axes.text(loc[0], loc[1], loc[2], str(i), verticalalignment='bottom') + axes.text( + loc[0], loc[1], loc[2], str(i), verticalalignment="bottom" + ) else: axes.scatter(xyz[:, axis[0]], xyz[:, axis[1]], s=area, c=colors, alpha=0.8) if atom_indices: for i, loc in enumerate(xyz): - axes.text(loc[axis[0]], loc[axis[1]], str(i), verticalalignment='bottom') + axes.text( + loc[axis[0]], loc[axis[1]], str(i), verticalalignment="bottom" + ) - axes.set_xlabel('Ang') - axes.set_ylabel('Ang') + axes.set_xlabel("Ang") + axes.set_ylabel("Ang") return axes - def equal(self, - other: GeometryLike, - R: bool=True, - tol: float=1e-4) -> bool: - """ Whether two geometries are the same (optional not check of the orbital radius) + def equal(self, other: GeometryLike, R: bool = True, tol: float = 1e-4) -> bool: + """Whether two geometries are the same (optional not check of the orbital radius) Parameters ---------- @@ -4438,8 +4557,8 @@ def __eq__(self, other): def __ne__(self, other): return not (self == other) - def sparserij(self, dtype=np.float64, na_iR: int=1000, method='rand'): - """ Return the sparse matrix with all distances in the matrix + def sparserij(self, dtype=np.float64, na_iR: int = 1000, method="rand"): + """Return the sparse matrix with all distances in the matrix The sparse matrix will only be defined for the elements which have orbitals overlapping with other atoms. @@ -4464,6 +4583,7 @@ def sparserij(self, dtype=np.float64, na_iR: int=1000, method='rand'): distance : create a list of distances """ from .sparse_geometry import SparseAtom + rij = SparseAtom(self, nnzpr=20, dtype=dtype) # Get R @@ -4472,7 +4592,6 @@ def sparserij(self, dtype=np.float64, na_iR: int=1000, method='rand'): # Do the loop for ias, atoms in self.iter_block(iR=iR, method=method): - # Get all the indexed atoms... # This speeds up the searching for # coordinates... @@ -4480,18 +4599,22 @@ def sparserij(self, dtype=np.float64, na_iR: int=1000, method='rand'): # Loop the atoms inside for ia in ias: - idx, r = self.close(ia, R=R, atoms=atoms, atoms_xyz=atoms_xyz, ret_rij=True) - rij[ia, ia] = 0. + idx, r = self.close( + ia, R=R, atoms=atoms, atoms_xyz=atoms_xyz, ret_rij=True + ) + rij[ia, ia] = 0.0 rij[ia, idx[1]] = r[1] return rij - def distance(self, - atoms: Optional[AtomsArgument]=None, - R: Optional[float]=None, - tol: float=0.1, - method: str='average') -> Union[float, ndarray]: - """ Calculate the distances for all atoms in shells of radius `tol` within `max_R` + def distance( + self, + atoms: Optional[AtomsArgument] = None, + R: Optional[float] = None, + tol: float = 0.1, + method: str = "average", + ) -> Union[float, ndarray]: + """Calculate the distances for all atoms in shells of radius `tol` within `max_R` Parameters ---------- @@ -4548,25 +4671,25 @@ def distance(self, if R is None: R = self.maxR() if R < 0: - raise ValueError(f"{self.__class__.__name__}" - ".distance cannot determine the `R` parameter. " - "The internal `maxR()` is negative and thus not set. " - "Set an explicit value for `R`.") + raise ValueError( + f"{self.__class__.__name__}" + ".distance cannot determine the `R` parameter. " + "The internal `maxR()` is negative and thus not set. " + "Set an explicit value for `R`." + ) elif np.any(self.nsc > 1): maxR = fnorm(self.cell).max() # These loops could be leveraged if we look at angles... - for i, j, k in product([0, self.nsc[0] // 2], - [0, self.nsc[1] // 2], - [0, self.nsc[2] // 2]): + for i, j, k in product( + [0, self.nsc[0] // 2], [0, self.nsc[1] // 2], [0, self.nsc[2] // 2] + ): if i == 0 and j == 0 and k == 0: continue sc = [i, j, k] off = self.lattice.offset(sc) for ii, jj, kk in product([0, 1], [0, 1], [0, 1]): - o = self.cell[0] * ii + \ - self.cell[1] * jj + \ - self.cell[2] * kk + o = self.cell[0] * ii + self.cell[1] * jj + self.cell[2] * kk maxR = max(maxR, fnorm(off + o)) if R > maxR: @@ -4576,7 +4699,7 @@ def distance(self, tol = _a.asarrayd(tol).ravel() if len(tol) == 1: # Now we are in a position to determine the sizes - dR = _a.aranged(tol[0] * .5, R + tol[0] * .55, tol[0]) + dR = _a.aranged(tol[0] * 0.5, R + tol[0] * 0.55, tol[0]) else: dR = tol.copy() dR[0] *= 0.5 @@ -4588,11 +4711,11 @@ def distance(self, # Now finalize dR by ensuring all remaining segments are captured t = tol[-1] - dR = concatenate((dR, _a.aranged(dR[-1] + t, R + t * .55, t))) + dR = concatenate((dR, _a.aranged(dR[-1] + t, R + t * 0.55, t))) # Reduce to the largest value above R # This ensures that R, truly is the largest considered element - dR = dR[:(dR > R).nonzero()[0][0]+1] + dR = dR[: (dR > R).nonzero()[0][0] + 1] # Now we can figure out the list of atoms in each shell # First create the initial lists of shell atoms @@ -4609,19 +4732,24 @@ def distance(self, # Now parse all of the shells with the correct routine # First we grap the routine: if isinstance(method, str): - if method == 'median': + if method == "median": + def func(lst): return np.median(lst, overwrite_input=True) - elif method == 'mode': + elif method == "mode": from scipy.stats import mode + def func(lst): return mode(lst, keepdims=False)[0] + else: try: func = getattr(np, method) except Exception: - raise ValueError(f"{self.__class__.__name__}.distance `method` got wrong input value.") + raise ValueError( + f"{self.__class__.__name__}.distance `method` got wrong input value." + ) else: func = method @@ -4639,12 +4767,10 @@ def func(lst): return d - def within_inf(self, - lattice: Lattice, - periodic=None, - tol: float=1e-5, - origin=None) -> Tuple[ndarray, ndarray, ndarray]: - """ Find all atoms within a provided supercell + def within_inf( + self, lattice: Lattice, periodic=None, tol: float = 1e-5, origin=None + ) -> Tuple[ndarray, ndarray, ndarray]: + """Find all atoms within a provided supercell Note this function is rather different from `close` and `within`. Specifically this routine is returning *all* indices for the infinite @@ -4763,7 +4889,7 @@ def within_inf(self, # infinite supercell indices return self.sc2uc(idx), xyz, isc - def apply(self, data, func, mapper, axis: int=0, segments="atoms") -> ndarray: + def apply(self, data, func, mapper, axis: int = 0, segments="atoms") -> ndarray: r"""Apply a function `func` to the data along axis `axis` using the method specified This can be useful for applying conversions from orbital data to atomic data through @@ -4821,47 +4947,51 @@ def apply(self, data, func, mapper, axis: int=0, segments="atoms") -> ndarray: elif segments == "all": segments = range(data.shape[axis]) else: - raise ValueError(f"{self.__class__}.apply got wrong argument 'segments'={segments}") + raise ValueError( + f"{self.__class__}.apply got wrong argument 'segments'={segments}" + ) # handle the data new_data = [ # execute func on the segmented data func(np.take(data, mapper(segment), axis), axis=axis) # loop each segment - for segment in segments] + for segment in segments + ] return np.stack(new_data, axis=axis) # Create pickling routines def __getstate__(self): - """ Returns the state of this object """ + """Returns the state of this object""" d = self.lattice.__getstate__() - d['xyz'] = self.xyz - d['atoms'] = self.atoms.__getstate__() + d["xyz"] = self.xyz + d["atoms"] = self.atoms.__getstate__() return d def __setstate__(self, d): - """ Re-create the state of this object """ + """Re-create the state of this object""" lattice = Lattice([1, 1, 1]) lattice.__setstate__(d) atoms = Atoms() - atoms.__setstate__(d['atoms']) - self.__init__(d['xyz'], atoms=atoms, lattice=lattice) + atoms.__setstate__(d["atoms"]) + self.__init__(d["xyz"], atoms=atoms, lattice=lattice) @classmethod def _ArgumentParser_args_single(cls): - """ Returns the options for `Geometry.ArgumentParser` in case they are the only options """ - return {'limit_arguments': False, - 'short': True, - 'positional_out': True, - } + """Returns the options for `Geometry.ArgumentParser` in case they are the only options""" + return { + "limit_arguments": False, + "short": True, + "positional_out": True, + } # Hook into the Geometry class to create # an automatic ArgumentParser which makes actions # as the options are read. @default_ArgumentParser(description="Manipulate a Geometry object in sisl.") def ArgumentParser(self, p=None, *args, **kwargs): - """ Create and return a group of argument parsers which manipulates it self `Geometry`. + """Create and return a group of argument parsers which manipulates it self `Geometry`. Parameters ---------- @@ -4877,13 +5007,16 @@ def ArgumentParser(self, p=None, *args, **kwargs): positional_out : bool, optional If ``True``, adds a positional argument which acts as --out. This may be handy if only the geometry is in the argument list. """ - limit_args = kwargs.get('limit_arguments', True) - short = kwargs.get('short', False) + limit_args = kwargs.get("limit_arguments", True) + short = kwargs.get("short", False) if short: + def opts(*args): return args + else: + def opts(*args): return [arg for arg in args if arg.startswith("--")] @@ -4907,38 +5040,58 @@ def opts(*args): class Format(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._geom_fmt = value[0] - p.add_argument(*opts('--format'), action=Format, nargs=1, default='.8f', - help='Specify output format for coordinates.') + + p.add_argument( + *opts("--format"), + action=Format, + nargs=1, + default=".8f", + help="Specify output format for coordinates.", + ) class MoveOrigin(argparse.Action): def __call__(self, parser, ns, no_value, option_string=None): ns._geometry.xyz[:, :] -= np.amin(ns._geometry.xyz, axis=0)[None, :] - p.add_argument(*opts('--origin', '-O'), action=MoveOrigin, nargs=0, - help='Move all atoms such that the smallest value along each Cartesian direction will be at the origin.') + + p.add_argument( + *opts("--origin", "-O"), + action=MoveOrigin, + nargs=0, + help="Move all atoms such that the smallest value along each Cartesian direction will be at the origin.", + ) class MoveCenterOf(argparse.Action): def __call__(self, parser, ns, value, option_string=None): - xyz = ns._geometry.center(what='xyz') - ns._geometry = ns._geometry.translate(ns._geometry.center(what=value) - xyz) - p.add_argument(*opts('--center-of', '-co'), - choices=["mass", "mass:pbc", "xyz", "position", "cell", "mm:xyz"], - action=MoveCenterOf, - help='Move coordinates to the center of the designated choice.') + xyz = ns._geometry.center(what="xyz") + ns._geometry = ns._geometry.translate( + ns._geometry.center(what=value) - xyz + ) + + p.add_argument( + *opts("--center-of", "-co"), + choices=["mass", "mass:pbc", "xyz", "position", "cell", "mm:xyz"], + action=MoveCenterOf, + help="Move coordinates to the center of the designated choice.", + ) class MoveUnitCell(argparse.Action): def __call__(self, parser, ns, value, option_string=None): - if value in ['translate', 'tr', 't']: + if value in ["translate", "tr", "t"]: # Simple translation tmp = np.amin(ns._geometry.xyz, axis=0) ns._geometry = ns._geometry.translate(-tmp) - elif value in ['mod']: + elif value in ["mod"]: g = ns._geometry # Change all coordinates using the reciprocal cell and move to unit-cell (% 1.) - fxyz = g.fxyz % 1. + fxyz = g.fxyz % 1.0 ns._geometry.xyz[:, :] = dot(fxyz, g.cell) - p.add_argument(*opts('--unit-cell', '-uc'), choices=['translate', 'tr', 't', 'mod'], - action=MoveUnitCell, - help='Moves the coordinates into the unit-cell by translation or the mod-operator') + + p.add_argument( + *opts("--unit-cell", "-uc"), + choices=["translate", "tr", "t", "mod"], + action=MoveUnitCell, + help="Moves the coordinates into the unit-cell by translation or the mod-operator", + ) # Rotation class Rotation(argparse.Action): @@ -4947,37 +5100,55 @@ def __call__(self, parser, ns, values, option_string=None): # The rotate function expects degree ang = angle(values[0], rad=False, in_rad=False) ns._geometry = ns._geometry.rotate(ang, values[1], what="abc+xyz") - p.add_argument(*opts('--rotate', '-R'), nargs=2, metavar=('ANGLE', 'DIR'), - action=Rotation, - help='Rotate coordinates and lattice vectors around given axis (x|y|z|a|b|c). ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.') + + p.add_argument( + *opts("--rotate", "-R"), + nargs=2, + metavar=("ANGLE", "DIR"), + action=Rotation, + help='Rotate coordinates and lattice vectors around given axis (x|y|z|a|b|c). ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.', + ) if not limit_args: + class RotationX(argparse.Action): def __call__(self, parser, ns, value, option_string=None): # The rotate function expects degree ang = angle(value, rad=False, in_rad=False) ns._geometry = ns._geometry.rotate(ang, "x", what="abc+xyz") - p.add_argument(*opts('--rotate-x', '-Rx'), metavar='ANGLE', - action=RotationX, - help='Rotate coordinates and lattice vectors around x axis. ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.') + + p.add_argument( + *opts("--rotate-x", "-Rx"), + metavar="ANGLE", + action=RotationX, + help='Rotate coordinates and lattice vectors around x axis. ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.', + ) class RotationY(argparse.Action): def __call__(self, parser, ns, value, option_string=None): # The rotate function expects degree ang = angle(value, rad=False, in_rad=False) ns._geometry = ns._geometry.rotate(ang, "y", what="abc+xyz") - p.add_argument(*opts('--rotate-y', '-Ry'), metavar='ANGLE', - action=RotationY, - help='Rotate coordinates and lattice vectors around y axis. ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.') + + p.add_argument( + *opts("--rotate-y", "-Ry"), + metavar="ANGLE", + action=RotationY, + help='Rotate coordinates and lattice vectors around y axis. ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.', + ) class RotationZ(argparse.Action): def __call__(self, parser, ns, value, option_string=None): # The rotate function expects degree ang = angle(value, rad=False, in_rad=False) ns._geometry = ns._geometry.rotate(ang, "z", what="abc+xyz") - p.add_argument(*opts('--rotate-z', '-Rz'), metavar='ANGLE', - action=RotationZ, - help='Rotate coordinates and lattice vectors around z axis. ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.') + + p.add_argument( + *opts("--rotate-z", "-Rz"), + metavar="ANGLE", + action=RotationZ, + help='Rotate coordinates and lattice vectors around z axis. ANGLE defaults to be specified in degree. Prefix with "r" for input in radians.', + ) # Reduce size of geometry class ReduceSub(argparse.Action): @@ -4985,18 +5156,26 @@ def __call__(self, parser, ns, value, option_string=None): # Get atomic indices rng = lstranges(strmap(int, value)) ns._geometry = ns._geometry.sub(rng) - p.add_argument('--sub', metavar='RNG', - action=ReduceSub, - help='Retains specified atoms, can be complex ranges.') + + p.add_argument( + "--sub", + metavar="RNG", + action=ReduceSub, + help="Retains specified atoms, can be complex ranges.", + ) class ReduceRemove(argparse.Action): def __call__(self, parser, ns, value, option_string=None): # Get atomic indices rng = lstranges(strmap(int, value)) ns._geometry = ns._geometry.remove(rng) - p.add_argument('--remove', metavar='RNG', - action=ReduceRemove, - help='Removes specified atoms, can be complex ranges.') + + p.add_argument( + "--remove", + metavar="RNG", + action=ReduceRemove, + help="Removes specified atoms, can be complex ranges.", + ) # Swaps atoms class AtomSwap(argparse.Action): @@ -5005,33 +5184,52 @@ def __call__(self, parser, ns, value, option_string=None): a = lstranges(strmap(int, value[0])) b = lstranges(strmap(int, value[1])) if len(a) != len(b): - raise ValueError('swapping atoms requires equal number of LHS and RHS atomic ranges') + raise ValueError( + "swapping atoms requires equal number of LHS and RHS atomic ranges" + ) ns._geometry = ns._geometry.swap(a, b) - p.add_argument(*opts('--swap'), metavar=('A', 'B'), nargs=2, - action=AtomSwap, - help='Swaps groups of atoms (can be complex ranges). The groups must be of equal length.') + + p.add_argument( + *opts("--swap"), + metavar=("A", "B"), + nargs=2, + action=AtomSwap, + help="Swaps groups of atoms (can be complex ranges). The groups must be of equal length.", + ) # Add an atom class AtomAdd(argparse.Action): def __call__(self, parser, ns, values, option_string=None): # Create an atom from the input - g = Geometry([float(x) for x in values[0].split(',')], atoms=Atom(values[1])) + g = Geometry( + [float(x) for x in values[0].split(",")], atoms=Atom(values[1]) + ) ns._geometry = ns._geometry.add(g) - p.add_argument(*opts('--add'), nargs=2, metavar=('COORD', 'Z'), - action=AtomAdd, - help='Adds an atom, coordinate is comma separated (in Ang). Z is the atomic number.') + + p.add_argument( + *opts("--add"), + nargs=2, + metavar=("COORD", "Z"), + action=AtomAdd, + help="Adds an atom, coordinate is comma separated (in Ang). Z is the atomic number.", + ) class Translate(argparse.Action): def __call__(self, parser, ns, values, option_string=None): # Create an atom from the input - if ',' in values[0]: - xyz = [float(x) for x in values[0].split(',')] + if "," in values[0]: + xyz = [float(x) for x in values[0].split(",")] else: xyz = [float(x) for x in values[0].split()] ns._geometry = ns._geometry.translate(xyz) - p.add_argument(*opts('--translate', '-t'), nargs=1, metavar='COORD', - action=Translate, - help='Translates the coordinates via a comma separated list (in Ang).') + + p.add_argument( + *opts("--translate", "-t"), + nargs=1, + metavar="COORD", + action=Translate, + help="Translates the coordinates via a comma separated list (in Ang).", + ) # Periodicly increase the structure class PeriodRepeat(argparse.Action): @@ -5039,80 +5237,126 @@ def __call__(self, parser, ns, values, option_string=None): r = int(values[0]) d = direction(values[1]) ns._geometry = ns._geometry.repeat(r, d) - p.add_argument(*opts('--repeat', '-r'), nargs=2, metavar=('TIMES', 'DIR'), - action=PeriodRepeat, - help='Repeats the geometry in the specified direction.') + + p.add_argument( + *opts("--repeat", "-r"), + nargs=2, + metavar=("TIMES", "DIR"), + action=PeriodRepeat, + help="Repeats the geometry in the specified direction.", + ) if not limit_args: + class PeriodRepeatX(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._geometry = ns._geometry.repeat(int(value), 0) - p.add_argument(*opts('--repeat-x', '-rx'), metavar='TIMES', - action=PeriodRepeatX, - help='Repeats the geometry along the first cell vector.') + + p.add_argument( + *opts("--repeat-x", "-rx"), + metavar="TIMES", + action=PeriodRepeatX, + help="Repeats the geometry along the first cell vector.", + ) class PeriodRepeatY(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._geometry = ns._geometry.repeat(int(value), 1) - p.add_argument(*opts('--repeat-y', '-ry'), metavar='TIMES', - action=PeriodRepeatY, - help='Repeats the geometry along the second cell vector.') + + p.add_argument( + *opts("--repeat-y", "-ry"), + metavar="TIMES", + action=PeriodRepeatY, + help="Repeats the geometry along the second cell vector.", + ) class PeriodRepeatZ(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._geometry = ns._geometry.repeat(int(value), 2) - p.add_argument(*opts('--repeat-z', '-rz'), metavar='TIMES', - action=PeriodRepeatZ, - help='Repeats the geometry along the third cell vector.') + + p.add_argument( + *opts("--repeat-z", "-rz"), + metavar="TIMES", + action=PeriodRepeatZ, + help="Repeats the geometry along the third cell vector.", + ) class ReduceUnrepeat(argparse.Action): def __call__(self, parser, ns, values, option_string=None): s = int(values[0]) d = direction(values[1]) ns._geometry = ns._geometry.unrepeat(s, d) - p.add_argument(*opts('--unrepeat', '-ur'), nargs=2, metavar=('REPS', 'DIR'), - action=ReduceUnrepeat, - help='Unrepeats the geometry into `reps` parts along the unit-cell direction `dir` (opposite of --repeat).') + + p.add_argument( + *opts("--unrepeat", "-ur"), + nargs=2, + metavar=("REPS", "DIR"), + action=ReduceUnrepeat, + help="Unrepeats the geometry into `reps` parts along the unit-cell direction `dir` (opposite of --repeat).", + ) class PeriodTile(argparse.Action): def __call__(self, parser, ns, values, option_string=None): r = int(values[0]) d = direction(values[1]) ns._geometry = ns._geometry.tile(r, d) - p.add_argument(*opts('--tile'), nargs=2, metavar=('TIMES', 'DIR'), - action=PeriodTile, - help='Tiles the geometry in the specified direction.') + + p.add_argument( + *opts("--tile"), + nargs=2, + metavar=("TIMES", "DIR"), + action=PeriodTile, + help="Tiles the geometry in the specified direction.", + ) if not limit_args: + class PeriodTileX(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._geometry = ns._geometry.tile(int(value), 0) - p.add_argument(*opts('--tile-x', '-tx'), metavar='TIMES', - action=PeriodTileX, - help='Tiles the geometry along the first cell vector.') + + p.add_argument( + *opts("--tile-x", "-tx"), + metavar="TIMES", + action=PeriodTileX, + help="Tiles the geometry along the first cell vector.", + ) class PeriodTileY(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._geometry = ns._geometry.tile(int(value), 1) - p.add_argument(*opts('--tile-y', '-ty'), metavar='TIMES', - action=PeriodTileY, - help='Tiles the geometry along the second cell vector.') + + p.add_argument( + *opts("--tile-y", "-ty"), + metavar="TIMES", + action=PeriodTileY, + help="Tiles the geometry along the second cell vector.", + ) class PeriodTileZ(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._geometry = ns._geometry.tile(int(value), 2) - p.add_argument(*opts('--tile-z', '-tz'), metavar='TIMES', - action=PeriodTileZ, - help='Tiles the geometry along the third cell vector.') + + p.add_argument( + *opts("--tile-z", "-tz"), + metavar="TIMES", + action=PeriodTileZ, + help="Tiles the geometry along the third cell vector.", + ) class ReduceUntile(argparse.Action): def __call__(self, parser, ns, values, option_string=None): s = int(values[0]) d = direction(values[1]) ns._geometry = ns._geometry.untile(s, d) - p.add_argument(*opts('--untile', '--cut', '-ut'), nargs=2, metavar=('REPS', 'DIR'), - action=ReduceUntile, - help='Untiles the geometry into `reps` parts along the unit-cell direction `dir` (opposite of --tile).') + + p.add_argument( + *opts("--untile", "--cut", "-ut"), + nargs=2, + metavar=("REPS", "DIR"), + action=ReduceUntile, + help="Untiles the geometry into `reps` parts along the unit-cell direction `dir` (opposite of --tile).", + ) # append another geometry class Geometryend(argparse.Action): @@ -5125,38 +5369,48 @@ def __call__(self, parser, ns, values, option_string=None): class GeometryAppend(Geometryend): _method_pend = "append" - p.add_argument(*opts('--append'), nargs=2, metavar=('GEOM', 'DIR'), - action=GeometryAppend, - help='Appends another Geometry along direction DIR.') + + p.add_argument( + *opts("--append"), + nargs=2, + metavar=("GEOM", "DIR"), + action=GeometryAppend, + help="Appends another Geometry along direction DIR.", + ) class GeometryPrepend(Geometryend): _method_pend = "prepend" - p.add_argument(*opts('--prepend'), nargs=2, metavar=('GEOM', 'DIR'), - action=GeometryPrepend, - help='Prepends another Geometry along direction DIR.') + + p.add_argument( + *opts("--prepend"), + nargs=2, + metavar=("GEOM", "DIR"), + action=GeometryPrepend, + help="Prepends another Geometry along direction DIR.", + ) # Sort class Sort(argparse.Action): def __call__(self, parser, ns, values, option_string=None): # call geometry.sort(...) using appropriate keywords (and ordered dict) kwargs = OrderedDict() - opts = values[0].split(';') + opts = values[0].split(";") for i, opt in enumerate(opts): # Split for equal - opt = opt.split('=', 1) + opt = opt.split("=", 1) if len(opt) > 1: opt, val = opt else: opt = opt[0] val = "True" - if val.lower() in ['t', 'true']: + if val.lower() in ["t", "true"]: val = True - elif val.lower() in ['f', 'false']: + elif val.lower() in ["f", "false"]: val = False - elif opt in ['atol']: + elif opt in ["atol"]: # float values val = float(val) - elif opt == 'group': + elif opt == "group": pass else: # it must be a range/tuple @@ -5165,21 +5419,29 @@ def __call__(self, parser, ns, values, option_string=None): # we always add integers to allow users to use the same keywords on commandline kwargs[opt.strip() + str(i)] = val ns._geometry = ns._geometry.sort(**kwargs) - p.add_argument(*opts('--sort'), nargs=1, metavar='SORT', - action=Sort, - help='Semi-colon separated options for sort, please always encapsulate in quotation ["axis=0;descend;lattice=(1, 2);group=Z"].') + + p.add_argument( + *opts("--sort"), + nargs=1, + metavar="SORT", + action=Sort, + help='Semi-colon separated options for sort, please always encapsulate in quotation ["axis=0;descend;lattice=(1, 2);group=Z"].', + ) # Print some common information about the # geometry (to stdout) class PrintInfo(argparse.Action): - def __call__(self, parser, ns, no_value, option_string=None): # We fake that it has been stored... ns._stored_geometry = True print(ns._geometry) - p.add_argument(*opts('--info'), nargs=0, - action=PrintInfo, - help='Print, to stdout, some regular information about the geometry.') + + p.add_argument( + *opts("--info"), + nargs=0, + action=PrintInfo, + help="Print, to stdout, some regular information about the geometry.", + ) class Out(argparse.Action): def __call__(self, parser, ns, value, option_string=None): @@ -5189,30 +5451,40 @@ def __call__(self, parser, ns, value, option_string=None): return # If the vector, exists, we should write it kwargs = {} - if hasattr(ns, '_geom_fmt'): - kwargs['fmt'] = ns._geom_fmt - if hasattr(ns, '_vector'): - v = getattr(ns, '_vector') - vs = getattr(ns, '_vector_scale') + if hasattr(ns, "_geom_fmt"): + kwargs["fmt"] = ns._geom_fmt + if hasattr(ns, "_vector"): + v = getattr(ns, "_vector") + vs = getattr(ns, "_vector_scale") if isinstance(vs, bool): if vs: - vs = 1. / np.max(sqrt(square(v).sum(1))) + vs = 1.0 / np.max(sqrt(square(v).sum(1))) info(f"Scaling vector by: {vs}") else: - vs = 1. + vs = 1.0 # Store the vectors with the scaling - kwargs['data'] = v * vs + kwargs["data"] = v * vs ns._geometry.write(value[0], **kwargs) # Issue to the namespace that the geometry has been written, at least once. ns._stored_geometry = True - p.add_argument(*opts('--out', '-o'), nargs=1, action=Out, - help='Store the geometry (at its current invocation) to the out file.') + + p.add_argument( + *opts("--out", "-o"), + nargs=1, + action=Out, + help="Store the geometry (at its current invocation) to the out file.", + ) # If the user requests positional out arguments, we also add that. - if kwargs.get('positional_out', False): - p.add_argument('out', nargs='*', default=None, action=Out, - help='Store the geometry (at its current invocation) to the out file.') + if kwargs.get("positional_out", False): + p.add_argument( + "out", + nargs="*", + default=None, + action=Out, + help="Store the geometry (at its current invocation) to the out file.", + ) # We have now created all arguments return p, namespace @@ -5224,7 +5496,7 @@ def __call__(self, parser, ns, value, option_string=None): # Define base-class for this class GeometryNewDispatch(AbstractDispatch): - """ Base dispatcher from class passing arguments to Geometry class + """Base dispatcher from class passing arguments to Geometry class This forwards all `__call__` calls to `dispatch` """ @@ -5236,16 +5508,18 @@ def __call__(self, *args, **kwargs): # Bypass regular Geometry to be returned as is class GeometryNewGeometryDispatch(GeometryNewDispatch): def dispatch(self, geometry, copy=False): - """ Return geometry as-is (no copy), for sanitization purposes """ + """Return geometry as-is (no copy), for sanitization purposes""" if copy: return geometry.copy() return geometry + + new_dispatch.register(Geometry, GeometryNewGeometryDispatch) class GeometryNewAseDispatch(GeometryNewDispatch): def dispatch(self, aseg, **kwargs): - """ Convert an ``ase`` object into a `Geometry` """ + """Convert an ``ase`` object into a `Geometry`""" cls = self._get_class() Z = aseg.get_atomic_numbers() xyz = aseg.get_positions() @@ -5253,12 +5527,15 @@ def dispatch(self, aseg, **kwargs): nsc = [3 if pbc else 1 for pbc in aseg.pbc] lattice = Lattice(cell, nsc=nsc) return cls(xyz, atoms=Z, lattice=lattice, **kwargs) + + new_dispatch.register("ase", GeometryNewAseDispatch) # currently we can't ensure the ase Atoms type # to get it by type(). That requires ase to be importable. try: from ase import Atoms as ase_Atoms + new_dispatch.register(ase_Atoms, GeometryNewAseDispatch) # ensure we don't pollute name-space del ase_Atoms @@ -5268,8 +5545,9 @@ def dispatch(self, aseg, **kwargs): class GeometryNewpymatgenDispatch(GeometryNewDispatch): def dispatch(self, struct, **kwargs): - """ Convert a ``pymatgen`` structure/molecule object into a `Geometry` """ + """Convert a ``pymatgen`` structure/molecule object into a `Geometry`""" from pymatgen.core import Structure + cls = self._get_class(allow_instance=True) Z = [] @@ -5282,12 +5560,14 @@ def dispatch(self, struct, **kwargs): if isinstance(struct, Structure): # we also have the lattice cell = struct.lattice.matrix - nsc = [3, 3, 3] # really, this is unknown + nsc = [3, 3, 3] # really, this is unknown else: - cell = xyz.max() - xyz.min(0) + 15. + cell = xyz.max() - xyz.min(0) + 15.0 nsc = [1, 1, 1] lattice = Lattice(cell, nsc=nsc) return cls(xyz, atoms=Z, lattice=lattice, **kwargs) + + new_dispatch.register("pymatgen", GeometryNewpymatgenDispatch) # currently we can't ensure the pymatgen classes @@ -5295,6 +5575,7 @@ def dispatch(self, struct, **kwargs): try: from pymatgen.core import Molecule as pymatgen_Molecule from pymatgen.core import Structure as pymatgen_Structure + new_dispatch.register(pymatgen_Molecule, GeometryNewpymatgenDispatch) new_dispatch.register(pymatgen_Structure, GeometryNewpymatgenDispatch) # ensure we don't pollute name-space @@ -5305,24 +5586,33 @@ def dispatch(self, struct, **kwargs): class GeometryNewFileDispatch(GeometryNewDispatch): def dispatch(self, *args, **kwargs): - """ Defer the `Geometry.read` method by passing down arguments """ + """Defer the `Geometry.read` method by passing down arguments""" # can work either on class or instance return self._obj.read(*args, **kwargs) + + new_dispatch.register(str, GeometryNewFileDispatch) new_dispatch.register(Path, GeometryNewFileDispatch) # see sisl/__init__.py for new_dispatch.register(BaseSile, GeometryNewFileDispatcher) class GeometryToDispatch(AbstractDispatch): - """ Base dispatcher from class passing from Geometry class """ + """Base dispatcher from class passing from Geometry class""" class GeometryToAseDispatch(GeometryToDispatch): def dispatch(self, **kwargs): from ase import Atoms as ase_Atoms + geom = self._get_object() - return ase_Atoms(symbols=geom.atoms.Z, positions=geom.xyz.tolist(), - cell=geom.cell.tolist(), pbc=geom.nsc > 1, **kwargs) + return ase_Atoms( + symbols=geom.atoms.Z, + positions=geom.xyz.tolist(), + cell=geom.cell.tolist(), + pbc=geom.nsc > 1, + **kwargs, + ) + to_dispatch.register("ase", GeometryToAseDispatch) @@ -5347,6 +5637,7 @@ def dispatch(self, **kwargs): return Molecule(species, xyz, **kwargs) return Structure(lattice, species, xyz, coords_are_cartesian=True, **kwargs) + to_dispatch.register("pymatgen", GeometryTopymatgenDispatch) @@ -5354,6 +5645,8 @@ class GeometryToSileDispatch(GeometryToDispatch): def dispatch(self, *args, **kwargs): geom = self._get_object() return geom.write(*args, **kwargs) + + to_dispatch.register("str", GeometryToSileDispatch) to_dispatch.register("Path", GeometryToSileDispatch) # to do geom.to[Path](path) @@ -5364,6 +5657,7 @@ def dispatch(self, *args, **kwargs): class GeometryToDataframeDispatch(GeometryToDispatch): def dispatch(self, *args, **kwargs): import pandas as pd + geom = self._get_object() # Now create data-frame @@ -5390,6 +5684,8 @@ def dispatch(self, *args, **kwargs): data["norbitals"] = atoms.orbitals return pd.DataFrame(data) + + to_dispatch.register("dataframe", GeometryToDataframeDispatch) # Remove references @@ -5398,7 +5694,7 @@ def dispatch(self, *args, **kwargs): @set_module("sisl") def sgeom(geometry=None, argv=None, ret_geometry=False): - """ Main script for sgeom. + """Main script for sgeom. This routine may be called with `argv` and/or a `Sile` which is the geometry at hand. @@ -5453,20 +5749,22 @@ def sgeom(geometry=None, argv=None, ret_geometry=False): if argv is not None: if len(argv) == 0: - argv = ['--help'] + argv = ["--help"] elif len(sys.argv) == 1: # no arguments # fake a help - argv = ['--help'] + argv = ["--help"] else: argv = sys.argv[1:] # Ensure that the arguments have pre-pended spaces argv = cmd.argv_negative_fix(argv) - p = argparse.ArgumentParser(exe, - formatter_class=argparse.RawDescriptionHelpFormatter, - description=description) + p = argparse.ArgumentParser( + exe, + formatter_class=argparse.RawDescriptionHelpFormatter, + description=description, + ) # Add default sisl version stuff cmd.add_sisl_version_cite_arg(p) @@ -5475,6 +5773,7 @@ def sgeom(geometry=None, argv=None, ret_geometry=False): stdout_geom = True if geometry is None: from os.path import isfile + argv, input_file = cmd.collect_input(argv) if input_file is None: @@ -5488,6 +5787,7 @@ def sgeom(geometry=None, argv=None, ret_geometry=False): geometry = get_sile(input_file).read_geometry() else: from .messages import info + info(f"Cannot find file '{input_file}'!") geometry = Geometry stdout_geom = False @@ -5508,15 +5808,15 @@ def sgeom(geometry=None, argv=None, ret_geometry=False): # and we will sort out if the input options # is only a help option. try: - if not hasattr(ns, '_input_file'): - setattr(ns, '_input_file', input_file) + if not hasattr(ns, "_input_file"): + setattr(ns, "_input_file", input_file) except Exception: pass # Now try and figure out the actual arguments - p, ns, argv = cmd.collect_arguments(argv, input=False, - argumentparser=p, - namespace=ns) + p, ns, argv = cmd.collect_arguments( + argv, input=False, argumentparser=p, namespace=ns + ) # We are good to go!!! args = p.parse_args(argv, namespace=ns) @@ -5525,15 +5825,18 @@ def sgeom(geometry=None, argv=None, ret_geometry=False): if stdout_geom and not args._stored_geometry: # We should write out the information to the stdout # This is merely for testing purposes and may not be used for anything. - print('Cell:') + print("Cell:") for i in (0, 1, 2): - print(' {:10.6f} {:10.6f} {:10.6f}'.format(*g.cell[i, :])) - print('Lattice:') - print(' {:d} {:d} {:d}'.format(*g.nsc)) - print(' {:>10s} {:>10s} {:>10s} {:>3s}'.format('x', 'y', 'z', 'Z')) + print(" {:10.6f} {:10.6f} {:10.6f}".format(*g.cell[i, :])) + print("Lattice:") + print(" {:d} {:d} {:d}".format(*g.nsc)) + print(" {:>10s} {:>10s} {:>10s} {:>3s}".format("x", "y", "z", "Z")) for ia in g: - print(' {1:10.6f} {2:10.6f} {3:10.6f} {0:3d}'.format(g.atoms[ia].Z, - *g.xyz[ia, :])) + print( + " {1:10.6f} {2:10.6f} {3:10.6f} {0:3d}".format( + g.atoms[ia].Z, *g.xyz[ia, :] + ) + ) if ret_geometry: return g diff --git a/src/sisl/grid.py b/src/sisl/grid.py index 662620b45d..f1983198bd 100644 --- a/src/sisl/grid.py +++ b/src/sisl/grid.py @@ -29,7 +29,7 @@ ) from .utils.mathematics import fnorm -__all__ = ['Grid', 'sgrid'] +__all__ = ["Grid", "sgrid"] _log = logging.getLogger("sisl") _log.info(f"adding logger: {__name__}") @@ -38,7 +38,7 @@ @set_module("sisl") class Grid(LatticeChild): - """ Real-space grid information with associated geometry. + """Real-space grid information with associated geometry. This grid object handles cell vectors and divisions of said grid. @@ -89,14 +89,19 @@ class Grid(LatticeChild): #: Constant for defining an open boundary condition OPEN = BoundaryCondition.OPEN - @deprecate_argument("sc", "lattice", - "argument sc has been deprecated in favor of lattice, please update your code.", - "0.15.0") - @deprecate_argument("bc", None, - "argument bc has been deprecated (removed) in favor of the boundary conditions in Lattice, please update your code.", - "0.15.0") + @deprecate_argument( + "sc", + "lattice", + "argument sc has been deprecated in favor of lattice, please update your code.", + "0.15.0", + ) + @deprecate_argument( + "bc", + None, + "argument bc has been deprecated (removed) in favor of the boundary conditions in Lattice, please update your code.", + "0.15.0", + ) def __init__(self, shape, bc=None, lattice=None, dtype=None, geometry=None): - self.set_lattice(None) # Create the atomic structure in the grid, if possible @@ -107,7 +112,7 @@ def __init__(self, shape, bc=None, lattice=None, dtype=None, geometry=None): self.set_lattice(lattice) if isinstance(shape, Real): - d = (self.cell ** 2).sum(1) ** 0.5 + d = (self.cell**2).sum(1) ** 0.5 shape = list(map(int, np.rint(d / shape))) # Create the grid @@ -117,28 +122,37 @@ def __init__(self, shape, bc=None, lattice=None, dtype=None, geometry=None): if bc is not None: self.lattice.set_boundary_condition(bc) - @deprecation("Grid.set_bc is deprecated since boundary conditions are moved to Lattice (see github issue #626", "0.15.0") + @deprecation( + "Grid.set_bc is deprecated since boundary conditions are moved to Lattice (see github issue #626", + "0.15.0", + ) def set_bc(self, bc): self.lattice.set_boundary_condition(bc) - @deprecation("Grid.set_boundary is deprecated since boundary conditions are moved to Lattice (see github issue #626", "0.15.0") + @deprecation( + "Grid.set_boundary is deprecated since boundary conditions are moved to Lattice (see github issue #626", + "0.15.0", + ) def set_boundary(self, bc): self.lattice.set_boundary_condition(bc) - @deprecation("Grid.set_boundary_condition is deprecated since boundary conditions are moved to Lattice (see github issue #626", "0.15.0") + @deprecation( + "Grid.set_boundary_condition is deprecated since boundary conditions are moved to Lattice (see github issue #626", + "0.15.0", + ) def set_boundary_condition(self, bc): self.lattice.set_boundary_condition(bc) def __getitem__(self, key): - """ Grid value at `key` """ + """Grid value at `key`""" return self.grid[key] def __setitem__(self, key, val): - """ Updates the grid contained """ + """Updates the grid contained""" self.grid[key] = val def _is_commensurate(self): - """Determine whether the contained geometry and lattice are commensurate """ + """Determine whether the contained geometry and lattice are commensurate""" if self.geometry is None: return True # ideally this should be checked that they are integer equivalent @@ -152,8 +166,8 @@ def _is_commensurate(self): return False return np.all(abs(reps - np.round(reps)) < 1e-5) - def set_geometry(self, geometry, also_lattice: bool=True): - """ Sets the `Geometry` for the grid. + def set_geometry(self, geometry, also_lattice: bool = True): + """Sets the `Geometry` for the grid. Setting the `Geometry` for the grid is a possibility to attach atoms to the grid. @@ -179,7 +193,7 @@ def set_geometry(self, geometry, also_lattice: bool=True): self.set_lattice(geometry.lattice) def fill(self, val): - """ Fill the grid with this value + """Fill the grid with this value Parameters ---------- @@ -189,7 +203,7 @@ def fill(self, val): self.grid.fill(val) def interp(self, shape, order=1, mode="wrap", **kwargs): - """ Interpolate grid values to a new grid of a different shape + """Interpolate grid values to a new grid of a different shape It uses the `scipy.ndimage.zoom`, which creates a finer or more spaced grid using spline interpolation. @@ -218,7 +232,7 @@ def interp(self, shape, order=1, mode="wrap", **kwargs): if isinstance(order, str): method = order if method is not None: - order = {'linear': 1}.get(method, 3) + order = {"linear": 1}.get(method, 3) # And now we do the actual interpolation # Calculate the zoom_factors @@ -235,10 +249,10 @@ def isosurface(self, level, step_size=1, **kwargs): Parameters ---------- level: float - contour value to search for isosurfaces in the grid. + contour value to search for isosurfaces in the grid. If not given or None, the average of the min and max of the grid is used. step_size: int, optional - step size in voxels. Larger steps yield faster but coarser results. + step size in voxels. Larger steps yield faster but coarser results. The result will always be topologically correct though. **kwargs: optional arguments passed directly to `skimage.measure.marching_cubes` @@ -250,14 +264,14 @@ def isosurface(self, level, step_size=1, **kwargs): Verts. Spatial coordinates for V unique mesh vertices. numpy array of shape (n_faces, 3) - Faces. Define triangular faces via referencing vertex indices from verts. + Faces. Define triangular faces via referencing vertex indices from verts. This algorithm specifically outputs triangles, so each face has exactly three indices. numpy array of shape (V, 3) Normals. The normal direction at each vertex, as calculated from the data. numpy array of shape (V, 3) - Values. Gives a measure for the maximum value of the data in the local region near each vertex. + Values. Gives a measure for the maximum value of the data in the local region near each vertex. This can be used by visualization tools to apply a colormap to the mesh. See Also @@ -283,7 +297,7 @@ def isosurface(self, level, step_size=1, **kwargs): return (verts, *returns) def smooth(self, r=0.7, method="gaussian", mode="wrap", **kwargs): - """ Make a smoother grid by applying a filter. + """Make a smoother grid by applying a filter. Parameters ----------- @@ -313,16 +327,16 @@ def smooth(self, r=0.7, method="gaussian", mode="wrap", **kwargs): # Update the kwargs accordingly if method == "gaussian": - kwargs['sigma'] = pixels_r + kwargs["sigma"] = pixels_r elif method == "uniform": - kwargs['size'] = pixels_r * 2 + kwargs["size"] = pixels_r * 2 # This should raise an import error if the method does not exist func = import_attr(f"scipy.ndimage.{method}_filter") return self.apply(func, mode=mode, **kwargs) def apply(self, function_, *args, **kwargs): - """ Applies a function to the grid and returns a new grid. + """Applies a function to the grid and returns a new grid. You can also apply a function that does not return a grid (maybe you want to do some measurement). In that case, you will get the result instead of a `Grid`. @@ -355,55 +369,57 @@ def apply(self, function_, *args, **kwargs): @property def size(self): - """ Total number of elements in the grid """ + """Total number of elements in the grid""" return np.prod(self.grid.shape) @property def shape(self): - r""" Grid shape in :math:`x`, :math:`y`, :math:`z` directions """ + r"""Grid shape in :math:`x`, :math:`y`, :math:`z` directions""" return self.grid.shape @property def dtype(self): - """ Data-type used in grid """ + """Data-type used in grid""" return self.grid.dtype @property def dkind(self): - """ The data-type of the grid (in str) """ + """The data-type of the grid (in str)""" return np.dtype(self.grid.dtype).kind def set_grid(self, shape, dtype=None): - """ Create the internal grid of certain size. """ + """Create the internal grid of certain size.""" shape = _a.asarrayi(shape).ravel() if dtype is None: dtype = np.float64 if shape.size != 3: - raise ValueError(f"{self.__class__.__name__}.set_grid requires shape to be of length 3") + raise ValueError( + f"{self.__class__.__name__}.set_grid requires shape to be of length 3" + ) self.grid = np.zeros(shape, dtype=dtype) def _sc_geometry_dict(self): - """ Internal routine for copying the Lattice and Geometry """ + """Internal routine for copying the Lattice and Geometry""" d = dict() - d['lattice'] = self.lattice.copy() + d["lattice"] = self.lattice.copy() if not self.geometry is None: - d['geometry'] = self.geometry.copy() + d["geometry"] = self.geometry.copy() return d def copy(self, dtype=None): - r""" Copy the object, possibly changing the data-type """ + r"""Copy the object, possibly changing the data-type""" d = self._sc_geometry_dict() if dtype is None: - d['dtype'] = self.dtype + d["dtype"] = self.dtype else: - d['dtype'] = dtype + d["dtype"] = dtype grid = self.__class__([1] * 3, **d) # This also ensures the shape is copied! - grid.grid = self.grid.astype(dtype=d['dtype']) + grid.grid = self.grid.astype(dtype=d["dtype"]) return grid def swapaxes(self, a, b): - """ Swap two axes in the grid (also swaps axes in the lattice) + """Swap two axes in the grid (also swaps axes in the lattice) If ``swapaxes(0,1)`` it returns the 0 in the 1 values. @@ -418,22 +434,22 @@ def swapaxes(self, a, b): idx[a] = b s = np.copy(self.shape) d = self._sc_geometry_dict() - d['lattice'] = d['lattice'].swapaxes(a, b) - d['dtype'] = self.dtype + d["lattice"] = d["lattice"].swapaxes(a, b) + d["dtype"] = self.dtype grid = self.__class__(s[idx], **d) # We need to force the C-order or we loose the contiguity - grid.grid = np.copy(np.swapaxes(self.grid, a, b), order='C') + grid.grid = np.copy(np.swapaxes(self.grid, a, b), order="C") return grid @property def dcell(self): - """ Voxel cell size """ + """Voxel cell size""" # Calculate the grid-distribution return self.cell / _a.asarrayi(self.shape).reshape(3, 1) @property def dvolume(self): - """ Volume of the grid voxel elements """ + """Volume of the grid voxel elements""" return self.lattice.volume / self.size def _copy_sub(self, n, axis, scale_geometry=False): @@ -444,7 +460,7 @@ def _copy_sub(self, n, axis, scale_geometry=False): cell[axis, :] = (cell[axis, :] / shape[axis]) * n shape[axis] = n if n < 1: - raise ValueError('You cannot retain no indices.') + raise ValueError("You cannot retain no indices.") grid = self.__class__(shape, dtype=self.dtype, **self._sc_geometry_dict()) # Update cell shape (the cell is smaller now) grid.set_lattice(cell) @@ -458,10 +474,10 @@ def _copy_sub(self, n, axis, scale_geometry=False): return grid def cross_section(self, idx, axis): - """ Takes a cross-section of the grid along axis `axis` + """Takes a cross-section of the grid along axis `axis` Remark: This API entry might change to handle arbitrary - cuts via rotation of the axis """ + cuts via rotation of the axis""" idx = _a.asarrayi(idx).ravel() grid = self._copy_sub(1, axis) @@ -477,7 +493,7 @@ def cross_section(self, idx, axis): return grid def sum(self, axis): - """ Sum grid values along axis `axis`. + """Sum grid values along axis `axis`. Parameters ---------- @@ -490,7 +506,7 @@ def sum(self, axis): return grid def average(self, axis, weights=None): - """ Average grid values along direction `axis`. + """Average grid values along direction `axis`. Parameters ---------- @@ -517,7 +533,9 @@ def average(self, axis, weights=None): elif axis == 2: grid.grid[:, :, 0] = np.average(self.grid, axis=axis, weights=weights) else: - raise ValueError(f"{self.__class__.__name__}.average requires axis to be in [0, 1, 2]") + raise ValueError( + f"{self.__class__.__name__}.average requires axis to be in [0, 1, 2]" + ) return grid @@ -525,7 +543,7 @@ def average(self, axis, weights=None): mean = average def remove_part(self, idx, axis, above): - """ Removes parts of the grid via above/below designations. + """Removes parts of the grid via above/below designations. Works exactly opposite to `sub_part` @@ -546,7 +564,7 @@ def remove_part(self, idx, axis, above): return self.sub_part(idx, axis, not above) def sub_part(self, idx, axis, above): - """ Retains parts of the grid via above/below designations. + """Retains parts of the grid via above/below designations. Works exactly opposite to `remove_part` @@ -571,7 +589,7 @@ def sub_part(self, idx, axis, above): return self.sub(sub, axis) def sub(self, idx, axis): - """ Retains certain indices from a specified axis. + """Retains certain indices from a specified axis. Works exactly opposite to `remove`. @@ -610,7 +628,7 @@ def sub(self, idx, axis): return grid def remove(self, idx, axis): - """ Removes certain indices from a specified axis. + """Removes certain indices from a specified axis. Works exactly opposite to `sub`. @@ -625,7 +643,7 @@ def remove(self, idx, axis): return self.sub(ret_idx, axis) def tile(self, reps, axis): - """ Tile grid to create a bigger one + """Tile grid to create a bigger one The atomic indices for the base Geometry will be retained. @@ -645,8 +663,10 @@ def tile(self, reps, axis): Geometry.tile : equivalent method for Geometry class """ if not self._is_commensurate(): - raise SislError(f"{self.__class__.__name__} cannot tile the grid since the contained" - " Geometry and Lattice are not commensurate.") + raise SislError( + f"{self.__class__.__name__} cannot tile the grid since the contained" + " Geometry and Lattice are not commensurate." + ) grid = self.copy() grid.grid = None reps_all = [1, 1, 1] @@ -659,7 +679,7 @@ def tile(self, reps, axis): return grid def index2xyz(self, index): - """ Real-space coordinates of indices related to the grid + """Real-space coordinates of indices related to the grid Parameters ---------- @@ -674,7 +694,7 @@ def index2xyz(self, index): return asarray(index).dot(self.dcell) def index_fold(self, index, unique=True): - """ Converts indices from *any* placement to only exist in the "primary" grid + """Converts indices from *any* placement to only exist in the "primary" grid Examples -------- @@ -702,7 +722,9 @@ def index_fold(self, index, unique=True): # Convert to internal if unique: - index = np.unique(index.reshape(-1, 3) % _a.asarrayi(self.shape)[None, :], axis=0) + index = np.unique( + index.reshape(-1, 3) % _a.asarrayi(self.shape)[None, :], axis=0 + ) else: index = index.reshape(-1, 3) % _a.asarrayi(self.shape)[None, :] @@ -711,7 +733,7 @@ def index_fold(self, index, unique=True): return index def index_truncate(self, index): - """ Remove indices from *outside* the grid to only retain indices in the "primary" grid + """Remove indices from *outside* the grid to only retain indices in the "primary" grid Examples -------- @@ -746,7 +768,7 @@ def index_truncate(self, index): return index def _index_shape(self, shape): - """ Internal routine for shape-indices """ + """Internal routine for shape-indices""" # First grab the sphere, subsequent indices will be reduced # by the actual shape cuboid = shape.to.Cuboid() @@ -789,7 +811,7 @@ def _index_shape(self, shape): return i def _index_shape_cuboid(self, cuboid): - """ Internal routine for cuboid shape-indices """ + """Internal routine for cuboid shape-indices""" # Construct all points on the outer rim of the cuboids min_d = fnorm(self.dcell).min() @@ -820,11 +842,11 @@ def plane(v1, v2): rxyz[1, :, :] = UR i = 0 - rxyz[:, i:i + sa * sb, :] += plane(a, b) + rxyz[:, i : i + sa * sb, :] += plane(a, b) i += sa * sb - rxyz[:, i:i + sa * sc, :] += plane(a, c) + rxyz[:, i : i + sa * sc, :] += plane(a, c) i += sa * sc - rxyz[:, i:i + sb * sc, :] += plane(b, c) + rxyz[:, i : i + sb * sc, :] += plane(b, c) del a, b, c, sa, sb, sc rxyz.shape = (-1, 3) @@ -832,7 +854,7 @@ def plane(v1, v2): return self.index(rxyz) def _index_shape_ellipsoid(self, ellipsoid): - """ Internal routine for ellipsoid shape-indices """ + """Internal routine for ellipsoid shape-indices""" # Figure out the points on the ellipsoid rad1 = pi / 180 theta, phi = ogrid[-pi:pi:rad1, 0:pi:rad1] @@ -849,7 +871,7 @@ def _index_shape_ellipsoid(self, ellipsoid): return self.index(rxyz) def index(self, coord, axis=None): - """ Find the grid index for a given coordinate (possibly only along a given lattice vector `axis`) + """Find the grid index for a given coordinate (possibly only along a given lattice vector `axis`) Parameters ---------- @@ -867,11 +889,13 @@ def index(self, coord, axis=None): return self._index_shape(coord) coord = _a.asarrayd(coord) - if coord.size == 1: # float + if coord.size == 1: # float if axis is None: - raise ValueError(f"{self.__class__.__name__}.index requires the " - "coordinate to be 3 values when an axis has not " - "been specified.") + raise ValueError( + f"{self.__class__.__name__}.index requires the " + "coordinate to be 3 values when an axis has not " + "been specified." + ) c = (self.dcell[axis, :] ** 2).sum() ** 0.5 return int(floor(coord / c)) @@ -889,31 +913,41 @@ def index(self, coord, axis=None): # each lattice vector) if axis is None: if ndim == 1: - return floor(dot(icell, coord.reshape(-1, 3).T) * shape).reshape(3).astype(int32, copy=False) + return ( + floor(dot(icell, coord.reshape(-1, 3).T) * shape) + .reshape(3) + .astype(int32, copy=False) + ) else: - return floor(dot(icell, coord.reshape(-1, 3).T) * shape).T.astype(int32, copy=False) + return floor(dot(icell, coord.reshape(-1, 3).T) * shape).T.astype( + int32, copy=False + ) if ndim == 1: - return floor(dot(icell[axis, :], coord.reshape(-1, 3).T) * shape[axis]).astype(int32, copy=False)[0] + return floor( + dot(icell[axis, :], coord.reshape(-1, 3).T) * shape[axis] + ).astype(int32, copy=False)[0] else: - return floor(dot(icell[axis, :], coord.reshape(-1, 3).T) * shape[axis]).T.astype(int32, copy=False) + return floor( + dot(icell[axis, :], coord.reshape(-1, 3).T) * shape[axis] + ).T.astype(int32, copy=False) def append(self, other, axis): - """ Appends other `Grid` to this grid along axis """ + """Appends other `Grid` to this grid along axis""" shape = list(self.shape) shape[axis] += other.shape[axis] d = self._sc_geometry_dict() - if 'geometry' in d: + if "geometry" in d: if not other.geometry is None: - d['geometry'] = d['geometry'].append(other.geometry, axis) + d["geometry"] = d["geometry"].append(other.geometry, axis) else: - d['geometry'] = other.geometry - d['lattice'] = self.lattice.append(other.lattice, axis) - d['dtype'] = self.dtype + d["geometry"] = other.geometry + d["lattice"] = self.lattice.append(other.lattice, axis) + d["dtype"] = self.dtype return self.__class__(shape, **d) @staticmethod def read(sile, *args, **kwargs): - """ Reads grid from the `Sile` using `read_grid` + """Reads grid from the `Sile` using `read_grid` Parameters ---------- @@ -925,21 +959,22 @@ def read(sile, *args, **kwargs): # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_grid(*args, **kwargs) else: sile = str(sile) sile, spec = str_spec(sile) if spec is not None: - if ',' in spec: - kwargs['index'] = list(map(float, spec.split(','))) + if "," in spec: + kwargs["index"] = list(map(float, spec.split(","))) else: - kwargs['index'] = int(spec) - with get_sile(sile, mode='r') as fh: + kwargs["index"] = int(spec) + with get_sile(sile, mode="r") as fh: return fh.read_grid(*args, **kwargs) def write(self, sile, *args, **kwargs) -> None: - """ Writes grid to the `Sile` using `write_grid` + """Writes grid to the `Sile` using `write_grid` Parameters ---------- @@ -951,26 +986,31 @@ def write(self, sile, *args, **kwargs) -> None: # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): sile.write_grid(self, *args, **kwargs) else: - with get_sile(sile, mode='w') as fh: + with get_sile(sile, mode="w") as fh: fh.write_grid(self, *args, **kwargs) def __str__(self): - """ String of object """ - s = '{name}{{kind: {kind}, shape: [{shape[0]} {shape[1]} {shape[2]}],\n'.format(kind=self.dkind, shape=self.shape, name=self.__class__.__name__) + """String of object""" + s = "{name}{{kind: {kind}, shape: [{shape[0]} {shape[1]} {shape[2]}],\n".format( + kind=self.dkind, shape=self.shape, name=self.__class__.__name__ + ) if self._is_commensurate() and self.geometry is not None: - l = np.round(self.lattice.length / self.geometry.lattice.length).astype(np.int32) + l = np.round(self.lattice.length / self.geometry.lattice.length).astype( + np.int32 + ) s += f"commensurate: [{l[0]} {l[1]} {l[2]}]" else: - s += '{}'.format(str(self.lattice).replace('\n', '\n ')) + s += "{}".format(str(self.lattice).replace("\n", "\n ")) if not self.geometry is None: - s += ',\n {}'.format(str(self.geometry).replace('\n', '\n ')) + s += ",\n {}".format(str(self.geometry).replace("\n", "\n ")) return f"{s}\n}}" def _check_compatibility(self, other, msg): - """ Internal check for asserting two grids are commensurable """ + """Internal check for asserting two grids are commensurable""" if self == other: return True s1 = str(self) @@ -978,30 +1018,30 @@ def _check_compatibility(self, other, msg): raise ValueError(f"Grids are not compatible, {s1}-{s2}. {msg}") def _compatible_copy(self, other, *args, **kwargs): - """ Internally used copy function that also checks whether the two grids are compatible """ + """Internally used copy function that also checks whether the two grids are compatible""" if isinstance(other, Grid): self._check_compatibility(other, *args, **kwargs) return self.copy() def __eq__(self, other): - """ Whether two grids are commensurable (no value checks, only grid shape) + """Whether two grids are commensurable (no value checks, only grid shape) - There will be no check of the values _on_ the grid. """ + There will be no check of the values _on_ the grid.""" return self.shape == other.shape def __ne__(self, other): - """ Whether two grids are incommensurable (no value checks, only grid shape) """ + """Whether two grids are incommensurable (no value checks, only grid shape)""" return not (self == other) def __abs__(self): - r""" Take the absolute value of the grid :math:`|grid|` """ + r"""Take the absolute value of the grid :math:`|grid|`""" dtype = dtype_complex_to_real(self.dtype) a = self.copy() a.grid = np.absolute(self.grid).astype(dtype, copy=False) return a def __add__(self, other): - """ Add two grid values (or add a single value to all grid values) + """Add two grid values (or add a single value to all grid values) Raises ------ @@ -1009,7 +1049,7 @@ def __add__(self, other): if the grids are not compatible (different shapes) """ if isinstance(other, Grid): - grid = self._compatible_copy(other, 'they cannot be added') + grid = self._compatible_copy(other, "they cannot be added") grid.grid = self.grid + other.grid else: grid = self.copy() @@ -1017,7 +1057,7 @@ def __add__(self, other): return grid def __iadd__(self, other): - """ Add, in-place, values from another grid + """Add, in-place, values from another grid Raises ------ @@ -1025,14 +1065,14 @@ def __iadd__(self, other): if the grids are not compatible (different shapes) """ if isinstance(other, Grid): - self._check_compatibility(other, 'they cannot be added') + self._check_compatibility(other, "they cannot be added") self.grid += other.grid else: self.grid += other return self def __sub__(self, other): - """ Subtract two grid values (or subtract a single value from all grid values) + """Subtract two grid values (or subtract a single value from all grid values) Raises ------ @@ -1040,7 +1080,7 @@ def __sub__(self, other): if the grids are not compatible (different shapes) """ if isinstance(other, Grid): - grid = self._compatible_copy(other, 'they cannot be subtracted') + grid = self._compatible_copy(other, "they cannot be subtracted") np.subtract(self.grid, other.grid, out=grid.grid) else: grid = self.copy() @@ -1048,7 +1088,7 @@ def __sub__(self, other): return grid def __isub__(self, other): - """ Subtract, in-place, values from another grid + """Subtract, in-place, values from another grid Raises ------ @@ -1056,7 +1096,7 @@ def __isub__(self, other): if the grids are not compatible (different shapes) """ if isinstance(other, Grid): - self._check_compatibility(other, 'they cannot be subtracted') + self._check_compatibility(other, "they cannot be subtracted") self.grid -= other.grid else: self.grid -= other @@ -1070,7 +1110,7 @@ def __idiv__(self, other): def __truediv__(self, other): if isinstance(other, Grid): - grid = self._compatible_copy(other, 'they cannot be divided') + grid = self._compatible_copy(other, "they cannot be divided") np.divide(self.grid, other.grid, out=grid.grid) else: grid = self.copy() @@ -1079,7 +1119,7 @@ def __truediv__(self, other): def __itruediv__(self, other): if isinstance(other, Grid): - self._check_compatibility(other, 'they cannot be divided') + self._check_compatibility(other, "they cannot be divided") self.grid /= other.grid else: self.grid /= other @@ -1087,7 +1127,7 @@ def __itruediv__(self, other): def __mul__(self, other): if isinstance(other, Grid): - grid = self._compatible_copy(other, 'they cannot be multiplied') + grid = self._compatible_copy(other, "they cannot be multiplied") np.multiply(self.grid, other.grid, out=grid.grid) else: grid = self.copy() @@ -1096,7 +1136,7 @@ def __mul__(self, other): def __imul__(self, other): if isinstance(other, Grid): - self._check_compatibility(other, 'they cannot be multiplied') + self._check_compatibility(other, "they cannot be multiplied") self.grid *= other.grid else: self.grid *= other @@ -1106,7 +1146,7 @@ def __imul__(self, other): # work-through case with other programs. @classmethod def mgrid(cls, *slices): - """ Return a list of indices corresponding to the slices + """Return a list of indices corresponding to the slices The returned values are equivalent to `numpy.mgrid` but they are returned in a (:, 3) array. @@ -1134,7 +1174,7 @@ def mgrid(cls, *slices): return indices def pyamg_index(self, index): - r""" Calculate `pyamg` matrix indices from a list of grid indices + r"""Calculate `pyamg` matrix indices from a list of grid indices Parameters ---------- @@ -1159,14 +1199,16 @@ def pyamg_index(self, index): index = _a.asarrayi(index).reshape(-1, 3) grid = _a.arrayi(self.shape[:]) if np.any(index < 0) or np.any(index >= grid.reshape(1, 3)): - raise ValueError(f"{self.__class__.__name__}.pyamg_index erroneous values for grid indices") + raise ValueError( + f"{self.__class__.__name__}.pyamg_index erroneous values for grid indices" + ) # Skipping factor per element cp = _a.arrayi([[grid[1] * grid[2], grid[2], 1]]) return (cp * index).sum(1) @classmethod def pyamg_source(cls, b, pyamg_indices, value): - r""" Fix the source term to `value`. + r"""Fix the source term to `value`. Parameters ---------- @@ -1179,7 +1221,7 @@ def pyamg_source(cls, b, pyamg_indices, value): b[pyamg_indices] = value def pyamg_fix(self, A, b, pyamg_indices, value): - r""" Fix values for the stencil to `value`. + r"""Fix values for the stencil to `value`. Parameters ---------- @@ -1194,17 +1236,19 @@ def pyamg_fix(self, A, b, pyamg_indices, value): the value of the grid to fix the value at """ if not A.format in ("csc", "csr"): - raise ValueError(f"{self.__class__.__name__}.pyamg_fix only works for csr/csc sparse matrices") + raise ValueError( + f"{self.__class__.__name__}.pyamg_fix only works for csr/csc sparse matrices" + ) # Clean all couplings between the respective indices and all other data - s = _a.array_arange(A.indptr[pyamg_indices], A.indptr[pyamg_indices+1]) - A.data[s] = 0. + s = _a.array_arange(A.indptr[pyamg_indices], A.indptr[pyamg_indices + 1]) + A.data[s] = 0.0 # clean-up del s # Specify that these indices are not to be tampered with d = np.zeros(A.shape[0], dtype=A.dtype) - d[pyamg_indices] = 1. + d[pyamg_indices] = 1.0 # BUG in scipy, sparse matrix += does not do in-place operations # hence we need to overwrite the `A` matrix afterward AA = A + sp_diags(d, format=A.format) @@ -1221,7 +1265,7 @@ def pyamg_fix(self, A, b, pyamg_indices, value): @wrap_filterwarnings("ignore", category=SparseEfficiencyWarning) def pyamg_boundary_condition(self, A, b, bc=None): - r""" Attach boundary conditions to the `pyamg` grid-matrix `A` with default boundary conditions as specified for this `Grid` + r"""Attach boundary conditions to the `pyamg` grid-matrix `A` with default boundary conditions as specified for this `Grid` Parameters ---------- @@ -1234,18 +1278,21 @@ def pyamg_boundary_condition(self, A, b, bc=None): Default to the grid's boundary conditions, else `bc` *must* be a list of elements with elements corresponding to `Grid.PERIODIC`/`Grid.NEUMANN`... """ + def Neumann(idx_bc, idx_p1): # Set all boundary equations to 0 - s = _a.array_arange(A.indptr[idx_bc], A.indptr[idx_bc+1]) + s = _a.array_arange(A.indptr[idx_bc], A.indptr[idx_bc + 1]) A.data[s] = 0 # force the boundary cells to equal the neighbouring cell A[idx_bc, idx_bc] = 1 A[idx_bc, idx_p1] = -1 A.eliminate_zeros() - b[idx_bc] = 0. + b[idx_bc] = 0.0 + def Dirichlet(idx): # Default pyamg Poisson matrix has Dirichlet BC - b[idx] = 0. + b[idx] = 0.0 + def Periodic(idx1, idx2): A[idx1, idx2] = -1 A[idx2, idx1] = -1 @@ -1263,15 +1310,16 @@ def sl2idx(sl): # LOWER BOUNDARY bci = self.boundary_condition[i] new_sl[i] = slice(0, 1) - idx1 = sl2idx(new_sl) # lower + idx1 = sl2idx(new_sl) # lower bc = bci[0] - if bci[0] == self.PERIODIC or \ - bci[1] == self.PERIODIC: + if bci[0] == self.PERIODIC or bci[1] == self.PERIODIC: if bci[0] != bci[1]: - raise ValueError(f"{self.__class__.__name__}.pyamg_boundary_condition found a periodic and non-periodic direction in the same direction!") - new_sl[i] = slice(self.shape[i]-1, self.shape[i]) - idx2 = sl2idx(new_sl) # upper + raise ValueError( + f"{self.__class__.__name__}.pyamg_boundary_condition found a periodic and non-periodic direction in the same direction!" + ) + new_sl[i] = slice(self.shape[i] - 1, self.shape[i]) + idx2 = sl2idx(new_sl) # upper Periodic(idx1, idx2) del idx2 continue @@ -1279,7 +1327,7 @@ def sl2idx(sl): if bc == self.NEUMANN: # Retrieve next index new_sl[i] = slice(1, 2) - idx2 = sl2idx(new_sl) # lower + 1 + idx2 = sl2idx(new_sl) # lower + 1 Neumann(idx1, idx2) del idx2 elif bc == self.DIRICHLET: @@ -1287,13 +1335,13 @@ def sl2idx(sl): # UPPER BOUNDARY bc = bci[1] - new_sl[i] = slice(self.shape[i]-1, self.shape[i]) - idx1 = sl2idx(new_sl) # upper + new_sl[i] = slice(self.shape[i] - 1, self.shape[i]) + idx1 = sl2idx(new_sl) # upper if bc == self.NEUMANN: # Retrieve next index - new_sl[i] = slice(self.shape[i]-2, self.shape[i]-1) - idx2 = sl2idx(new_sl) # upper - 1 + new_sl[i] = slice(self.shape[i] - 2, self.shape[i] - 1) + idx2 = sl2idx(new_sl) # upper - 1 Neumann(idx1, idx2) del idx2 elif bc == self.DIRICHLET: @@ -1302,7 +1350,7 @@ def sl2idx(sl): A.eliminate_zeros() def topyamg(self, dtype=None): - r""" Create a `pyamg` stencil matrix to be used in pyamg + r"""Create a `pyamg` stencil matrix to be used in pyamg This allows retrieving the grid matrix equivalent of the real-space grid. Subsequently the returned matrix may be used in pyamg for solutions etc. @@ -1343,10 +1391,11 @@ def topyamg(self, dtype=None): pyamg_boundary_condition : setup the sparse matrix ``A`` to given boundary conditions (called in this routine) """ from pyamg.gallery import poisson + if dtype is None: dtype = self.dtype # Initially create the CSR matrix - A = poisson(self.shape, dtype=dtype, format='csr') + A = poisson(self.shape, dtype=dtype, format="csr") b = np.zeros(A.shape[0], dtype=A.dtype) # Now apply the boundary conditions @@ -1355,18 +1404,19 @@ def topyamg(self, dtype=None): @classmethod def _ArgumentParser_args_single(cls): - """ Returns the options for `Grid.ArgumentParser` in case they are the only options """ - return {'limit_arguments': False, - 'short': True, - 'positional_out': True, - } + """Returns the options for `Grid.ArgumentParser` in case they are the only options""" + return { + "limit_arguments": False, + "short": True, + "positional_out": True, + } # Hook into the Grid class to create # an automatic ArgumentParser which makes actions # as the options are read. @default_ArgumentParser(description="Manipulate a Grid object in sisl.") def ArgumentParser(self, p=None, *args, **kwargs): - """ Create and return a group of argument parsers which manipulates it self `Grid`. + """Create and return a group of argument parsers which manipulates it self `Grid`. Parameters ---------- @@ -1380,7 +1430,7 @@ def ArgumentParser(self, p=None, *args, **kwargs): positional_out : bool, False If `True`, adds a positional argument which acts as --out. This may be handy if only the geometry is in the argument list. """ - short = kwargs.get('short', False) + short = kwargs.get("short", False) def opts(*args): if short: @@ -1402,49 +1452,60 @@ def opts(*args): # Define actions class SetGeometry(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): ns._geometry = Geometry.read(value) ns._grid.set_geometry(ns._geometry) - p.add_argument(*opts('--geometry', '-G'), action=SetGeometry, - help='Define the geometry attached to the Grid.') + + p.add_argument( + *opts("--geometry", "-G"), + action=SetGeometry, + help="Define the geometry attached to the Grid.", + ) # subtract another grid # They *MUST* be comensurate. class DiffGrid(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): grid = Grid.read(value) ns._grid -= grid - p.add_argument(*opts('--diff', '-d'), action=DiffGrid, - help='Subtract another grid (they must be commensurate).') - class AverageGrid(argparse.Action): + p.add_argument( + *opts("--diff", "-d"), + action=DiffGrid, + help="Subtract another grid (they must be commensurate).", + ) + class AverageGrid(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._grid = ns._grid.average(direction(value)) - p.add_argument(*opts('--average'), metavar='DIR', - action=AverageGrid, - help='Take the average of the grid along DIR.') - class SumGrid(argparse.Action): + p.add_argument( + *opts("--average"), + metavar="DIR", + action=AverageGrid, + help="Take the average of the grid along DIR.", + ) + class SumGrid(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._grid = ns._grid.sum(direction(value)) - p.add_argument(*opts('--sum'), metavar='DIR', - action=SumGrid, - help='Take the sum of the grid along DIR.') + + p.add_argument( + *opts("--sum"), + metavar="DIR", + action=SumGrid, + help="Take the sum of the grid along DIR.", + ) # Create-subsets of the grid class SubDirectionGrid(argparse.Action): - def __call__(self, parser, ns, values, option_string=None): # The unit-cell direction axis = direction(values[1]) # Figure out whether this is a fractional or # distance in Ang - is_frac = 'f' in values[0] - rng = strseq(float, values[0].replace('f', '')) + is_frac = "f" in values[0] + rng = strseq(float, values[0].replace("f", "")) if isinstance(rng, tuple): if is_frac: rng = tuple(rng) @@ -1459,7 +1520,7 @@ def __call__(self, parser, ns, values, option_string=None): idx2 = ns._grid.index(rng[1], axis=axis) ns._grid = ns._grid.sub(_a.arangei(idx1, idx2), axis) return - elif rng < 0.: + elif rng < 0.0: if is_frac: rng = ns._grid.cell[axis, :] * abs(rng) b = False @@ -1469,24 +1530,30 @@ def __call__(self, parser, ns, values, option_string=None): b = True idx = ns._grid.index(rng, axis=axis) ns._grid = ns._grid.sub_part(idx, axis, b) - p.add_argument(*opts('--sub'), nargs=2, metavar=('COORD', 'DIR'), - action=SubDirectionGrid, - help='Reduce the grid by taking a subset of the grid (along DIR).') + + p.add_argument( + *opts("--sub"), + nargs=2, + metavar=("COORD", "DIR"), + action=SubDirectionGrid, + help="Reduce the grid by taking a subset of the grid (along DIR).", + ) # Create-subsets of the grid class RemoveDirectionGrid(argparse.Action): - def __call__(self, parser, ns, values, option_string=None): # The unit-cell direction axis = direction(values[1]) # Figure out whether this is a fractional or # distance in Ang - is_frac = 'f' in values[0] - rng = strseq(float, values[0].replace('f', '')) + is_frac = "f" in values[0] + rng = strseq(float, values[0].replace("f", "")) if isinstance(rng, tuple): # we have bounds if not (rng[0] is None or rng[1] is None): - raise NotImplementedError('Can not figure out how to apply mid-removal of grids.') + raise NotImplementedError( + "Can not figure out how to apply mid-removal of grids." + ) if rng[0] is None: idx1 = 0 else: @@ -1497,7 +1564,7 @@ def __call__(self, parser, ns, values, option_string=None): idx2 = ns._grid.index(rng[1], axis=axis) ns._grid = ns._grid.remove(_a.arangei(idx1, idx2), axis) return - elif rng < 0.: + elif rng < 0.0: if is_frac: rng = ns._grid.cell[axis, :] * abs(rng) b = True @@ -1507,76 +1574,100 @@ def __call__(self, parser, ns, values, option_string=None): b = False idx = ns._grid.index(rng, axis=axis) ns._grid = ns._grid.remove_part(idx, axis, b) - p.add_argument(*opts('--remove'), nargs=2, metavar=('COORD', 'DIR'), - action=RemoveDirectionGrid, - help='Reduce the grid by removing a subset of the grid (along DIR).') - class Tile(argparse.Action): + p.add_argument( + *opts("--remove"), + nargs=2, + metavar=("COORD", "DIR"), + action=RemoveDirectionGrid, + help="Reduce the grid by removing a subset of the grid (along DIR).", + ) + class Tile(argparse.Action): def __call__(self, parser, ns, values, option_string=None): r = int(values[0]) d = direction(values[1]) ns._grid = ns._grid.tile(r, d) - p.add_argument(*opts('--tile'), nargs=2, metavar=('TIMES', 'DIR'), - action=Tile, - help='Tiles the grid in the specified direction.') + + p.add_argument( + *opts("--tile"), + nargs=2, + metavar=("TIMES", "DIR"), + action=Tile, + help="Tiles the grid in the specified direction.", + ) # Scale the grid with this value class ScaleGrid(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): ns._grid.grid *= value - p.add_argument(*opts('--scale', '-S'), type=float, - action=ScaleGrid, - help='Scale grid values with a factor') + + p.add_argument( + *opts("--scale", "-S"), + type=float, + action=ScaleGrid, + help="Scale grid values with a factor", + ) # Define size of grid class InterpGrid(argparse.Action): - def __call__(self, parser, ns, values, option_string=None): def _conv_shape(length, value): if "." in value: return int(round(length / float(value))) return int(value) + shape = list(map(_conv_shape, ns._grid.lattice.length, values[:3])) # shorten list for easier arguments values = values[3:] if len(values) > 0: values[0] = int(values[0]) ns._grid = ns._grid.interp(shape, *values) - p.add_argument(*opts('--interp'), nargs='*', metavar='NX NY NZ *ORDER *MODE', - action=InterpGrid, - help="""Interpolate grid for higher or lower density (minimum 3 arguments) + + p.add_argument( + *opts("--interp"), + nargs="*", + metavar="NX NY NZ *ORDER *MODE", + action=InterpGrid, + help="""Interpolate grid for higher or lower density (minimum 3 arguments) Requires at least 3 arguments, number of points along 1st, 2nd and 3rd lattice vector. These may contain a "." to signal a distance in angstrom of each voxel. For instance --interp 0.1 10 100 will result in an interpolated shape of [nint(grid.lattice.length / 0.1), 10, 100]. The 4th optional argument is the order of interpolation; an integer 0<=i<=5 (default 1) The 5th optional argument is the mode to interpolate; wrap/mirror/constant/reflect/nearest -""") +""", + ) # Smoothen the grid class SmoothGrid(argparse.Action): - def __call__(self, parser, ns, values, option_string=None): if len(values) > 0: values[0] = float(values[0]) ns._grid = ns._grid.smooth(*values) - p.add_argument(*opts('--smooth'), nargs='*', metavar='*R *METHOD *MODE', - action=SmoothGrid, - help="""Smoothen grid values according to methods by applying a filter, all arguments are optional. + + p.add_argument( + *opts("--smooth"), + nargs="*", + metavar="*R *METHOD *MODE", + action=SmoothGrid, + help="""Smoothen grid values according to methods by applying a filter, all arguments are optional. The 1st argument is the radius of the filter for smoothening, a larger value means a larger volume which is agglomerated The 2nd argument is the method to use; gaussian/uniform The 3rd argument is the mode to use; wrap/mirror/constant/reflect/nearest -""") +""", + ) class PrintInfo(argparse.Action): - def __call__(self, parser, ns, values, option_string=None): ns._stored_grid = True print(ns._grid) - p.add_argument(*opts('--info'), nargs=0, - action=PrintInfo, - help='Print, to stdout, some regular information about the grid.') + + p.add_argument( + *opts("--info"), + nargs=0, + action=PrintInfo, + help="Print, to stdout, some regular information about the grid.", + ) class Plot(argparse.Action): def __call__(self, parser, ns, values, option_string=None): @@ -1590,7 +1681,11 @@ def __call__(self, parser, ns, values, option_string=None): for ax in (0, 1, 2): shape = grid.shape[ax] if shape > 1: - axs.append(np.linspace(0, grid.lattice.length[ax], shape, endpoint=False)) + axs.append( + np.linspace( + 0, grid.lattice.length[ax], shape, endpoint=False + ) + ) idx.append(ax) # Now plot data @@ -1607,24 +1702,27 @@ def __call__(self, parser, ns, values, option_string=None): plt.ylabel(f"Arbitrary unit") plt.show() - p.add_argument(*opts('--plot', '-P'), nargs=0, - action=Plot, - help='Plot the grid (currently only enabled if at least one dimension has been averaged out') + p.add_argument( + *opts("--plot", "-P"), + nargs=0, + action=Plot, + help="Plot the grid (currently only enabled if at least one dimension has been averaged out", + ) class Out(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): if value is None: return if len(value) == 0: return from sisl.io import get_sile + grid = ns._grid # determine whether the write-out file has *write_grid* as a methd # if not, and the grid only have 1 dimension, we allow it to be # written to a datafile - sile = get_sile(value[0], 'w') + sile = get_sile(value[0], "w") if hasattr(sile, "write_grid"): grid.write(sile) elif np.prod(grid.shape) == np.amax(grid.shape): @@ -1633,25 +1731,38 @@ def __call__(self, parser, ns, value, option_string=None): # the distance along the lattice vector idx = np.argmax(grid.shape) - dx = np.linspace(0, grid.lattice.length[idx], grid.shape[idx], endpoint=False) + dx = np.linspace( + 0, grid.lattice.length[idx], grid.shape[idx], endpoint=False + ) sile.write_data(dx, grid.grid.ravel()) else: - raise ValueError(f"""Either of these two cases are not fullfilled: + raise ValueError( + f"""Either of these two cases are not fullfilled: 1. {sile} do not have the `write_grid` method -2. The grid is not 1D data; averaged or summed along 2 directions.""") +2. The grid is not 1D data; averaged or summed along 2 directions.""" + ) # Issue to the namespace that the grid has been written, at least once. ns._stored_grid = True - p.add_argument(*opts('--out', '-o'), nargs=1, action=Out, - help='Store the grid (at its current invocation) to the out file.') + + p.add_argument( + *opts("--out", "-o"), + nargs=1, + action=Out, + help="Store the grid (at its current invocation) to the out file.", + ) # If the user requests positional out arguments, we also add that. - if kwargs.get('positional_out', False): - p.add_argument('out', nargs='*', default=None, - action=Out, - help='Store the grid (at its current invocation) to the out file.') + if kwargs.get("positional_out", False): + p.add_argument( + "out", + nargs="*", + default=None, + action=Out, + help="Store the grid (at its current invocation) to the out file.", + ) # We have now created all arguments return p, namespace @@ -1659,7 +1770,7 @@ def __call__(self, parser, ns, value, option_string=None): @set_module("sisl") def sgrid(grid=None, argv=None, ret_grid=False): - """ Main script for sgrid. + """Main script for sgrid. This routine may be called with `argv` and/or a `Sile` which is the grid at hand. @@ -1702,20 +1813,22 @@ def sgrid(grid=None, argv=None, ret_grid=False): if argv is not None: if len(argv) == 0: - argv = ['--help'] + argv = ["--help"] elif len(sys.argv) == 1: # no arguments # fake a help - argv = ['--help'] + argv = ["--help"] else: argv = sys.argv[1:] # Ensure that the arguments have pre-pended spaces argv = cmd.argv_negative_fix(argv) - p = argparse.ArgumentParser(exe, - formatter_class=argparse.RawDescriptionHelpFormatter, - description=description) + p = argparse.ArgumentParser( + exe, + formatter_class=argparse.RawDescriptionHelpFormatter, + description=description, + ) # Add default sisl version stuff cmd.add_sisl_version_cite_arg(p) @@ -1724,6 +1837,7 @@ def sgrid(grid=None, argv=None, ret_grid=False): stdout_grid = True if grid is None: from os.path import isfile + argv, input_file = cmd.collect_input(argv) kwargs = {} @@ -1745,15 +1859,15 @@ def sgrid(grid=None, argv=None, ret_grid=False): # and we will sort out if the input options # is only a help option. try: - if not hasattr(ns, '_input_file'): - setattr(ns, '_input_file', input_file) + if not hasattr(ns, "_input_file"): + setattr(ns, "_input_file", input_file) except Exception: pass # Now try and figure out the actual arguments - p, ns, argv = cmd.collect_arguments(argv, input=False, - argumentparser=p, - namespace=ns) + p, ns, argv = cmd.collect_arguments( + argv, input=False, argumentparser=p, namespace=ns + ) # We are good to go!!! args = p.parse_args(argv, namespace=ns) diff --git a/src/sisl/io/_help.py b/src/sisl/io/_help.py index 44bc1de74b..1acc5aae5e 100644 --- a/src/sisl/io/_help.py +++ b/src/sisl/io/_help.py @@ -5,8 +5,7 @@ import numpy as np -__all__ = ["starts_with_list", "header_to_dict", - "grid_reduce_indices"] +__all__ = ["starts_with_list", "header_to_dict", "grid_reduce_indices"] def starts_with_list(l, comments): @@ -17,7 +16,7 @@ def starts_with_list(l, comments): def header_to_dict(header): - """ Convert a header line with 'key=val key1=val1' sequences to a single dictionary """ + """Convert a header line with 'key=val key1=val1' sequences to a single dictionary""" e = re_compile(r"(\S+)=") # 1. Remove *any* entry with 0 length @@ -29,20 +28,22 @@ def header_to_dict(header): d = {} while len(kv) >= 2: # We have reversed the list - key = kv.pop().strip(' =') # remove white-space *and* = - val = kv.pop().strip() # remove outer whitespace + key = kv.pop().strip(" =") # remove white-space *and* = + val = kv.pop().strip() # remove outer whitespace d[key] = val return d def grid_reduce_indices(grids, factors, axis=0, out=None): - """ Reduce `grids` into a single `grid` value along `axis` by summing the `factors` + """Reduce `grids` into a single `grid` value along `axis` by summing the `factors` If `out` is defined the data will be stored there. """ if len(factors) > grids.shape[axis]: - raise ValueError(f"Trying to reduce a grid with too many factors: {len(factors)} > {grids.shape[axis]}") + raise ValueError( + f"Trying to reduce a grid with too many factors: {len(factors)} > {grids.shape[axis]}" + ) if out is not None: grid = out diff --git a/src/sisl/io/_multiple.py b/src/sisl/io/_multiple.py index ccf0b62841..89b69d0021 100644 --- a/src/sisl/io/_multiple.py +++ b/src/sisl/io/_multiple.py @@ -10,7 +10,7 @@ class SileSlicer: - """ Handling io-methods in sliced behaviour for multiple returns + """Handling io-methods in sliced behaviour for multiple returns This class handler can expose a slicing behavior of the function that it applies to. @@ -19,13 +19,16 @@ class SileSlicer: let this perform the function at hand for slicing behaviour etc. """ - def __init__(self, - obj: Type[Any], - func: Func, - key: Type[Any], - *, - skip_func: Optional[Func]=None, - postprocess: Optional[Callable[..., Any]]=None): + + def __init__( + self, + obj: Type[Any], + func: Func, + key: Type[Any], + *, + skip_func: Optional[Func] = None, + postprocess: Optional[Callable[..., Any]] = None, + ): # this makes it work like a function bound to an instance (func._obj # works for instances) self._obj = obj @@ -37,12 +40,14 @@ def __init__(self, else: self.skip_func = skip_func if postprocess is None: + def postprocess(ret): return ret + self.postprocess = postprocess def __call__(self, *args, **kwargs): - """ Defer call to the function """ + """Defer call to the function""" # Now handle the arguments obj = self._obj func = self.__wrapped__ @@ -68,7 +73,7 @@ def check_none(r): if key >= 0: start = key stop = key + 1 - elif key.step is None or key.step > 0: # step size of 1 + elif key.step is None or key.step > 0: # step size of 1 if key.start is not None: start = key.start if key.stop is not None: @@ -88,7 +93,7 @@ def check_none(r): # collect returning values retvals = [None] * start append = retvals.append - with obj: # open sile + with obj: # open sile # quick-skip using the skip-function for _ in range(start): skip_func(obj, *args, **kwargs) @@ -114,24 +119,27 @@ def check_none(r): self.key = None if isinstance(key, Integral): return retvals[key] - + # else postprocess return self.postprocess(retvals[key]) class SileBound: - """ A bound method deferring stuff to the function + """A bound method deferring stuff to the function This class calls the function `func` when directly called but returns the `slicer` class when users slices this object. """ - def __init__(self, - obj: Type[Any], - func: Callable[..., Any], - *, - slicer: Type[SileSlicer]=SileSlicer, - default_slice: Optional[Any]=None, - **kwargs): + + def __init__( + self, + obj: Type[Any], + func: Callable[..., Any], + *, + slicer: Type[SileSlicer] = SileSlicer, + default_slice: Optional[Any] = None, + **kwargs, + ): self._obj = obj # first update to the wrapped function update_wrapper(self, func) @@ -168,7 +176,8 @@ def _update_doc(self): docs = [doc] docs.append( - dedent(f""" + dedent( + f""" Notes ----- This method defaults to return {default_slice} item(s). @@ -190,7 +199,8 @@ def _update_doc(self): While one can store the sliced function ``tmp = obj.{name}[:]`` one will loose the slice after each call. - """) + """ + ) ) doc = "\n".join(docs) try: @@ -206,31 +216,27 @@ def __call__(self, *args, **kwargs): return self[self.default_slice](*args, **kwargs) def __getitem__(self, key): - """Extract sub items of multiple function calls as an indexed list """ - return self.slicer( - obj=self._obj, - func=self.__wrapped__, - key=key, - **self.kwargs - ) + """Extract sub items of multiple function calls as an indexed list""" + return self.slicer(obj=self._obj, func=self.__wrapped__, key=key, **self.kwargs) @property def next(self): - """Return the first element of the contained function """ + """Return the first element of the contained function""" return self[0] @property def last(self): - """Return the last element of the contained function """ + """Return the last element of the contained function""" return self[-1] class SileBinder: - """ Bind a class instance to the function name it decorates + """Bind a class instance to the function name it decorates Enables to bypass a class method with another object to defer handling in specific cases. """ + def __init__(self, **kwargs): self.kwargs = kwargs @@ -248,20 +254,11 @@ def __get__(self, obj, objtype=None): # and other things, this one won't bind # the SileBound object to the function # name it arrived from. - bound = SileBound( - obj=objtype, - func=func, - **self.kwargs - ) + bound = SileBound(obj=objtype, func=func, **self.kwargs) else: - bound = SileBound( - obj=obj, - func=func, - **self.kwargs - ) + bound = SileBound(obj=obj, func=func, **self.kwargs) # bind the class object to the host # No more instantiation setattr(obj, func.__name__, bound) return bound - diff --git a/src/sisl/io/bigdft/__init__.py b/src/sisl/io/bigdft/__init__.py index fc36568941..f0903a5953 100644 --- a/src/sisl/io/bigdft/__init__.py +++ b/src/sisl/io/bigdft/__init__.py @@ -8,5 +8,5 @@ asciiSileBigDFT - the input for BigDFT """ -from .sile import * # isort: split +from .sile import * # isort: split from .ascii import * diff --git a/src/sisl/io/bigdft/ascii.py b/src/sisl/io/bigdft/ascii.py index d4473b619d..a3d3a2888f 100644 --- a/src/sisl/io/bigdft/ascii.py +++ b/src/sisl/io/bigdft/ascii.py @@ -17,21 +17,21 @@ __all__ = ["asciiSileBigDFT"] -Bohr2Ang = unit_convert('Bohr', 'Ang') +Bohr2Ang = unit_convert("Bohr", "Ang") @set_module("sisl.io.bigdft") class asciiSileBigDFT(SileBigDFT): - """ ASCII file object for BigDFT """ + """ASCII file object for BigDFT""" def _setup(self, *args, **kwargs): - """ Initialize for `asciiSileBigDFT` """ + """Initialize for `asciiSileBigDFT`""" super()._setup(*args, **kwargs) - self._comment = ['#', '!'] + self._comment = ["#", "!"] @sile_fh_open() def read_geometry(self): - """ Reads a supercell from the Sile """ + """Reads a supercell from the Sile""" # 1st line is arbitrary self.readline(True) @@ -51,29 +51,27 @@ def read_geometry(self): # Now we need to read through and find keywords try: while True: - # Read line also with comment l = self.readline(True) # Empty line, means EOF - if l == '': + if l == "": break # too short, continue if len(l) < 1: continue # Check for keyword - if l[1:].startswith('keyword:'): - if 'reduced' in l: + if l[1:].startswith("keyword:"): + if "reduced" in l: is_frac = True - if 'angdeg' in l: + if "angdeg" in l: is_angdeg = True - if 'bohr' in l or 'atomic' in l: + if "bohr" in l or "atomic" in l: is_bohr = True continue elif l[0] in self._comment: - # this is a comment, cycle continue @@ -111,7 +109,7 @@ def read_geometry(self): # The input is in skewed axis lattice = Lattice([dxx, dyx, dyy, dzx, dzy, dzz]) else: - lattice = Lattice([[dxx, 0., 0.], [dyx, dyy, 0.], [dzx, dzy, dzz]]) + lattice = Lattice([[dxx, 0.0, 0.0], [dyx, dyy, 0.0], [dzx, dzy, dzz]]) # Now create the geometry xyz = np.array(xyz, np.float64) @@ -130,27 +128,25 @@ def read_geometry(self): return Geometry(xyz, spec, lattice=lattice) @sile_fh_open() - def write_geometry(self, geom, fmt='.8f'): - """ Writes the geometry to the contained file """ + def write_geometry(self, geom, fmt=".8f"): + """Writes the geometry to the contained file""" # Check that we can write to the file sile_raise_write(self) # Write out the cell - self._write('# Created by sisl\n') + self._write("# Created by sisl\n") # We write the cell coordinates as the cell coordinates - fmt_str = f'{{:{fmt}}} ' * 3 + '\n' - self._write( - fmt_str.format( - geom.cell[0, 0], geom.cell[1, 0], geom.cell[1, 1])) + fmt_str = f"{{:{fmt}}} " * 3 + "\n" + self._write(fmt_str.format(geom.cell[0, 0], geom.cell[1, 0], geom.cell[1, 1])) self._write(fmt_str.format(*geom.cell[2, :])) # This also denotes - self._write('#keyword: angstroem\n') + self._write("#keyword: angstroem\n") - self._write('# Geometry containing: ' + str(len(geom)) + ' atoms\n') + self._write("# Geometry containing: " + str(len(geom)) + " atoms\n") - f1_str = '{{1:{0}}} {{2:{0}}} {{3:{0}}} {{0:2s}}\n'.format(fmt) - f2_str = '{{2:{0}}} {{3:{0}}} {{4:{0}}} {{0:2s}} {{1:s}}\n'.format(fmt) + f1_str = "{{1:{0}}} {{2:{0}}} {{3:{0}}} {{0:2s}}\n".format(fmt) + f2_str = "{{2:{0}}} {{3:{0}}} {{4:{0}}} {{0:2s}} {{1:s}}\n".format(fmt) for ia, a, _ in geom.iter_species(): if a.symbol != a.tag: @@ -158,10 +154,10 @@ def write_geometry(self, geom, fmt='.8f'): else: self._write(f1_str.format(a.symbol, *geom.xyz[ia, :])) # Add a single new line - self._write('\n') + self._write("\n") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) diff --git a/src/sisl/io/cube.py b/src/sisl/io/cube.py index 3b78e6dd8e..cf20473591 100644 --- a/src/sisl/io/cube.py +++ b/src/sisl/io/cube.py @@ -5,6 +5,7 @@ from sisl import Atom, Geometry, Grid, Lattice, SislError from sisl._internal import set_module + # Import sile objects from sisl.io.sile import * from sisl.messages import deprecate_argument @@ -17,7 +18,7 @@ @set_module("sisl.io") class cubeSile(Sile): - """ CUBE file object + """CUBE file object By default the cube file is written using Bohr units. one can define the units by passing a respective unit argument. @@ -26,11 +27,20 @@ class cubeSile(Sile): """ @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") - def write_lattice(self, lattice, fmt="15.10e", size=None, origin=None, - unit="Bohr", - *args, **kwargs): - """ Writes `Lattice` object attached to this grid + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) + def write_lattice( + self, + lattice, + fmt="15.10e", + size=None, + origin=None, + unit="Bohr", + *args, + **kwargs, + ): + """Writes `Lattice` object attached to this grid Parameters ---------- @@ -71,10 +81,17 @@ def write_lattice(self, lattice, fmt="15.10e", size=None, origin=None, self._write("1 0. 0. 0. 0.\n") @sile_fh_open() - def write_geometry(self, geometry, fmt="15.10e", size=None, origin=None, - unit="Bohr", - *args, **kwargs): - """ Writes `Geometry` object attached to this grid + def write_geometry( + self, + geometry, + fmt="15.10e", + size=None, + origin=None, + unit="Bohr", + *args, + **kwargs, + ): + """Writes `Geometry` object attached to this grid Parameters ---------- @@ -118,11 +135,13 @@ def write_geometry(self, geometry, fmt="15.10e", size=None, origin=None, tmp = " {:" + fmt + "}" _fmt = "{:d} 0.0" + tmp + tmp + tmp + "\n" for ia in geometry: - self._write(_fmt.format(geometry.atoms[ia].Z, *geometry.xyz[ia, :] * Ang2unit)) + self._write( + _fmt.format(geometry.atoms[ia].Z, *geometry.xyz[ia, :] * Ang2unit) + ) @sile_fh_open() def write_grid(self, grid, fmt=".5e", imag=False, unit="Bohr", *args, **kwargs): - """ Write `Grid` to the contained file + """Write `Grid` to the contained file Parameters ---------- @@ -144,12 +163,16 @@ def write_grid(self, grid, fmt=".5e", imag=False, unit="Bohr", *args, **kwargs): sile_raise_write(self) if grid.geometry is None: - self.write_lattice(grid.lattice, size=grid.shape, unit=unit, *args, **kwargs) + self.write_lattice( + grid.lattice, size=grid.shape, unit=unit, *args, **kwargs + ) else: - self.write_geometry(grid.geometry, size=grid.shape, unit=unit, *args, **kwargs) + self.write_geometry( + grid.geometry, size=grid.shape, unit=unit, *args, **kwargs + ) buffersize = kwargs.get("buffersize", min(6144, grid.grid.size)) - buffersize += buffersize % 6 # ensure multiple of 6 + buffersize += buffersize % 6 # ensure multiple of 6 # A CUBE file contains grid-points aligned like this: # for x @@ -161,15 +184,25 @@ def write_grid(self, grid, fmt=".5e", imag=False, unit="Bohr", *args, **kwargs): __fmt = _fmt6 * (buffersize // 6) if imag: - for z in np.nditer(np.asarray(grid.grid.imag, order="C").reshape(-1), flags=["external_loop", "buffered"], - op_flags=[["readonly"]], order="C", buffersize=buffersize): + for z in np.nditer( + np.asarray(grid.grid.imag, order="C").reshape(-1), + flags=["external_loop", "buffered"], + op_flags=[["readonly"]], + order="C", + buffersize=buffersize, + ): if z.shape[0] != buffersize: s = z.shape[0] __fmt = _fmt6 * (s // 6) + _fmt1 * (s % 6) + "\n" self._write(__fmt.format(*z.tolist())) else: - for z in np.nditer(np.asarray(grid.grid.real, order="C").reshape(-1), flags=["external_loop", "buffered"], - op_flags=[["readonly"]], order="C", buffersize=buffersize): + for z in np.nditer( + np.asarray(grid.grid.real, order="C").reshape(-1), + flags=["external_loop", "buffered"], + op_flags=[["readonly"]], + order="C", + buffersize=buffersize, + ): if z.shape[0] != buffersize: s = z.shape[0] __fmt = _fmt6 * (s // 6) + _fmt1 * (s % 6) + "\n" @@ -179,7 +212,7 @@ def write_grid(self, grid, fmt=".5e", imag=False, unit="Bohr", *args, **kwargs): self._write("\n") def _r_header_dict(self): - """ Reads the header of the file """ + """Reads the header of the file""" self.fh.seek(0) self.readline() header = header_to_dict(self.readline()) @@ -188,7 +221,7 @@ def _r_header_dict(self): @sile_fh_open() def read_lattice(self, na=False): - """ Returns `Lattice` object from the CUBE file + """Returns `Lattice` object from the CUBE file Parameters ---------- @@ -197,7 +230,7 @@ def read_lattice(self, na=False): """ unit2Ang = self._r_header_dict()["unit"] - origin = self.readline().split() # origin + origin = self.readline().split() # origin lna = int(origin[0]) origin = np.fromiter(map(float, origin[1:]), np.float64) @@ -217,7 +250,7 @@ def read_lattice(self, na=False): @sile_fh_open() def read_geometry(self): - """ Returns `Geometry` object from the CUBE file """ + """Returns `Geometry` object from the CUBE file""" unit2Ang = self._r_header_dict()["unit"] na, lattice = self.read_lattice(na=True) @@ -238,7 +271,7 @@ def read_geometry(self): @sile_fh_open() def read_grid(self, imag=None): - """ Returns `Grid` object from the CUBE file + """Returns `Grid` object from the CUBE file Parameters ---------- @@ -288,11 +321,15 @@ def read_grid(self, imag=None): # We are expecting an imaginary part if not grid.geometry.equal(imag.geometry): - raise SislError(f"{self!s} and its imaginary part does not have the same " - "geometry. Hence a combined complex Grid cannot be formed.") + raise SislError( + f"{self!s} and its imaginary part does not have the same " + "geometry. Hence a combined complex Grid cannot be formed." + ) if grid != imag: - raise SislError(f"{self!s} and its imaginary part does not have the same " - "shape. Hence a combined complex Grid cannot be formed.") + raise SislError( + f"{self!s} and its imaginary part does not have the same " + "shape. Hence a combined complex Grid cannot be formed." + ) # Now we have a complex grid grid.grid = grid.grid + 1j * imag.grid diff --git a/src/sisl/io/fhiaims/__init__.py b/src/sisl/io/fhiaims/__init__.py index 1fdb4364f6..31d10f9ae5 100644 --- a/src/sisl/io/fhiaims/__init__.py +++ b/src/sisl/io/fhiaims/__init__.py @@ -8,5 +8,5 @@ inSileFHIaims - geometry for FHI-aims """ -from .sile import * # isort: split +from .sile import * # isort: split from ._geometry import * diff --git a/src/sisl/io/fhiaims/_geometry.py b/src/sisl/io/fhiaims/_geometry.py index 4eccf7d33b..b3f34efeaa 100644 --- a/src/sisl/io/fhiaims/_geometry.py +++ b/src/sisl/io/fhiaims/_geometry.py @@ -18,12 +18,14 @@ @set_module("sisl.io.fhiaims") class inSileFHIaims(SileFHIaims): - """ FHI-aims geometry file object """ + """FHI-aims geometry file object""" @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def write_lattice(self, lattice, fmt=".8f"): - """ Writes the supercell to the contained file + """Writes the supercell to the contained file Parameters ---------- @@ -39,8 +41,10 @@ def write_lattice(self, lattice, fmt=".8f"): self._write(_fmt.format(*lattice.cell[2])) @sile_fh_open() - def write_geometry(self, geometry, fmt=".8f", as_frac=False, velocity=None, moment=None): - """ Writes the geometry to the contained file + def write_geometry( + self, geometry, fmt=".8f", as_frac=False, velocity=None, moment=None + ): + """Writes the geometry to the contained file Parameters ---------- @@ -80,7 +84,7 @@ def write_geometry(self, geometry, fmt=".8f", as_frac=False, velocity=None, mome @sile_fh_open() def read_lattice(self): - """ Reads supercell object from the file """ + """Reads supercell object from the file""" self.fh.seek(0) # read until "lattice_vector" is found @@ -93,7 +97,7 @@ def read_lattice(self): @sile_fh_open() def read_geometry(self, velocity=False, moment=False): - """ Reads Geometry object from the file + """Reads Geometry object from the file Parameters ---------- @@ -124,7 +128,9 @@ def read_geometry(self, velocity=False, moment=False): def ensure_length(l, length, add): if length < 0: - raise SileError("Found a velocity/initial_moment entry before an atom entry?") + raise SileError( + "Found a velocity/initial_moment entry before an atom entry?" + ) while len(l) < length: l.append(add) @@ -150,7 +156,7 @@ def ensure_length(l, length, add): # we found an atom sp.append(line[4]) - ret = (Geometry(xyz, atoms=sp, lattice=lattice), ) + ret = (Geometry(xyz, atoms=sp, lattice=lattice),) if not velocity and not moment: return ret[0] @@ -161,15 +167,15 @@ def ensure_length(l, length, add): return ret def read_velocity(self): - """ Reads velocity in the file """ + """Reads velocity in the file""" return self.read_geometry(velocity=True)[1] def read_moment(self): - """ Reads initial moment in the file """ + """Reads initial moment in the file""" return self.read_geometry(moment=True)[1] def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) diff --git a/src/sisl/io/gulp/__init__.py b/src/sisl/io/gulp/__init__.py index a3ec69fe44..66e16a52c2 100644 --- a/src/sisl/io/gulp/__init__.py +++ b/src/sisl/io/gulp/__init__.py @@ -9,6 +9,6 @@ fcSileGULP - force constant output from GULP """ -from .sile import * # isort: split +from .sile import * # isort: split from .fc import * from .got import * diff --git a/src/sisl/io/gulp/fc.py b/src/sisl/io/gulp/fc.py index 94c7f19a7d..2ea0d5f0b3 100644 --- a/src/sisl/io/gulp/fc.py +++ b/src/sisl/io/gulp/fc.py @@ -13,16 +13,16 @@ from ..sile import add_sile, sile_fh_open from .sile import SileGULP -__all__ = ['fcSileGULP'] +__all__ = ["fcSileGULP"] @set_module("sisl.io.gulp") class fcSileGULP(SileGULP): - """ GULP output file object """ + """GULP output file object""" @sile_fh_open() def read_force_constant(self, **kwargs): - """ Returns a sparse matrix in coo format which contains the GULP force constant matrix. + """Returns a sparse matrix in coo format which contains the GULP force constant matrix. This routine expects the units to be in eV/Ang**2. @@ -38,8 +38,8 @@ def read_force_constant(self, **kwargs): FC : force constant in `scipy.sparse.coo_matrix` format """ # Default cutoff - cutoff = kwargs.get('cutoff', 0.) - dtype = kwargs.get('dtype', np.float64) + cutoff = kwargs.get("cutoff", 0.0) + dtype = kwargs.get("dtype", np.float64) # Read number of atoms in the file... na = int(self.readline()) @@ -54,7 +54,6 @@ def read_force_constant(self, **kwargs): for ia in range(na): i = ia * 3 for ja in range(na): - # read line that should contain: # ia ja lsplit = rl().split() @@ -68,15 +67,15 @@ def read_force_constant(self, **kwargs): # much faster assignment fc[i, :] = tmp[0].ravel() - fc[i+1, :] = tmp[1].ravel() - fc[i+2, :] = tmp[2].ravel() + fc[i + 1, :] = tmp[1].ravel() + fc[i + 2, :] = tmp[2].ravel() # Convert to COO format fc = fc.tocoo() - fc.data[np.fabs(fc.data) < cutoff] = 0. + fc.data[np.fabs(fc.data) < cutoff] = 0.0 fc.eliminate_zeros() return fc -add_sile('FORCE_CONSTANTS_2ND', fcSileGULP, gzip=True) +add_sile("FORCE_CONSTANTS_2ND", fcSileGULP, gzip=True) diff --git a/src/sisl/io/gulp/got.py b/src/sisl/io/gulp/got.py index 905158fdb3..f9e407b1c8 100644 --- a/src/sisl/io/gulp/got.py +++ b/src/sisl/io/gulp/got.py @@ -20,7 +20,7 @@ @set_module("sisl.io.gulp") class gotSileGULP(SileGULP): - """ GULP output file object + """GULP output file object Parameters ---------- @@ -33,34 +33,36 @@ class gotSileGULP(SileGULP): """ def _setup(self, *args, **kwargs): - """ Setup `gotSileGULP` after initialization """ + """Setup `gotSileGULP` after initialization""" super()._setup(*args, **kwargs) self._keys = dict() - self.set_lattice_key('Cartesian lattice vectors') - self.set_geometry_key('Final fractional coordinates') - self.set_dynamical_matrix_key('Real Dynamical matrix') + self.set_lattice_key("Cartesian lattice vectors") + self.set_geometry_key("Final fractional coordinates") + self.set_dynamical_matrix_key("Real Dynamical matrix") def set_key(self, segment, key): - """ Sets the segment lookup key """ + """Sets the segment lookup key""" if key is not None: self._keys[segment] = key def set_lattice_key(self, key): - """ Overwrites internal key lookup value for the cell vectors """ - self.set_key('lattice', key) + """Overwrites internal key lookup value for the cell vectors""" + self.set_key("lattice", key) - set_supercell_key = deprecation("set_supercell_key is deprecated in favor of set_lattice_key", "0.15")(set_lattice_key) + set_supercell_key = deprecation( + "set_supercell_key is deprecated in favor of set_lattice_key", "0.15" + )(set_lattice_key) @sile_fh_open() def read_lattice_nsc(self, key=None): - """ Reads the dimensions of the supercell """ + """Reads the dimensions of the supercell""" - f, l = self.step_to('Supercell dimensions') + f, l = self.step_to("Supercell dimensions") if not f: return np.array([1, 1, 1], np.int32) # Read off the supercell dimensions - xyz = l.split('=')[1:] + xyz = l.split("=")[1:] # Now read off the quantities... nsc = [int(i.split()[0]) for i in xyz] @@ -69,16 +71,17 @@ def read_lattice_nsc(self, key=None): @sile_fh_open() def read_lattice(self, key=None, **kwargs): - """ Reads a `Lattice` and creates the GULP cell """ + """Reads a `Lattice` and creates the GULP cell""" self.set_lattice_key(key) - f, _ = self.step_to(self._keys['lattice']) + f, _ = self.step_to(self._keys["lattice"]) if not f: raise ValueError( - 'SileGULP tries to lookup the Lattice vectors ' - 'using key "' + self._keys['lattice'] + '". \n' - 'Use ".set_lattice_key(...)" to search for different name.\n' - 'This could not be found found in file: "' + self.file + '".') + "SileGULP tries to lookup the Lattice vectors " + 'using key "' + self._keys["lattice"] + '". \n' + 'Use ".set_lattice_key(...)" to search for different name.\n' + 'This could not be found found in file: "' + self.file + '".' + ) # skip 1 line self.readline() @@ -90,23 +93,27 @@ def read_lattice(self, key=None, **kwargs): return Lattice(cell) def set_geometry_key(self, key): - """ Overwrites internal key lookup value for the geometry vectors """ - self.set_key('geometry', key) + """Overwrites internal key lookup value for the geometry vectors""" + self.set_key("geometry", key) @sile_fh_open() def read_geometry(self, **kwargs): - """ Reads a geometry and creates the GULP dynamical geometry """ + """Reads a geometry and creates the GULP dynamical geometry""" # create default supercell lattice = Lattice([1, 1, 1]) for _ in [0, 1]: # Step to either the geometry or - f, _, ki = self.step_to([self._keys['lattice'], self._keys['geometry']], ret_index=True) + f, _, ki = self.step_to( + [self._keys["lattice"], self._keys["geometry"]], ret_index=True + ) if not f and ki == 0: - raise ValueError('SileGULP tries to lookup the Lattice vectors ' - 'using key "' + self._keys['lattice'] + '". \n' - 'Use ".set_lattice_key(...)" to search for different name.\n' - 'This could not be found found in file: "' + self.file + '".') + raise ValueError( + "SileGULP tries to lookup the Lattice vectors " + 'using key "' + self._keys["lattice"] + '". \n' + 'Use ".set_lattice_key(...)" to search for different name.\n' + 'This could not be found found in file: "' + self.file + '".' + ) elif f and ki == 0: # supercell self.readline() @@ -119,13 +126,14 @@ def read_geometry(self, **kwargs): lattice = Lattice(cell) elif not f and ki == 1: - raise ValueError('SileGULP tries to lookup the Geometry coordinates ' - 'using key "' + self._keys['geometry'] + '". \n' - 'Use ".set_geom_key(...)" to search for different name.\n' - 'This could not be found found in file: "' + self.file + '".') + raise ValueError( + "SileGULP tries to lookup the Geometry coordinates " + 'using key "' + self._keys["geometry"] + '". \n' + 'Use ".set_geom_key(...)" to search for different name.\n' + 'This could not be found found in file: "' + self.file + '".' + ) elif f and ki == 1: - - orbs = [Orbital(-1, tag=tag) for tag in 'xyz'] + orbs = [Orbital(-1, tag=tag) for tag in "xyz"] # We skip 5 lines for _ in [0] * 5: @@ -135,7 +143,7 @@ def read_geometry(self, **kwargs): xyz = [] while True: l = self.readline() - if l[0] == '-': + if l[0] == "-": break ls = l.split() @@ -147,17 +155,21 @@ def read_geometry(self, **kwargs): xyz.shape = (-1, 3) if len(Z) == 0 or len(xyz) == 0: - raise ValueError('Could not read in cell information and/or coordinates') + raise ValueError( + "Could not read in cell information and/or coordinates" + ) elif not f: # could not find either cell or geometry - raise ValueError('SileGULP tries to lookup the Lattice or Geometry.\n' - 'None succeeded, ensure file has correct format.\n' - 'This could not be found found in file: "{}".'.format(self.file)) + raise ValueError( + "SileGULP tries to lookup the Lattice or Geometry.\n" + "None succeeded, ensure file has correct format.\n" + 'This could not be found found in file: "{}".'.format(self.file) + ) # as the cell may be read in after the geometry we have # to wait until here to convert from fractional - if 'fractional' in self._keys['geometry'].lower(): + if "fractional" in self._keys["geometry"].lower(): # Correct for fractional coordinates xyz = np.dot(xyz, lattice.cell) @@ -165,13 +177,13 @@ def read_geometry(self, **kwargs): return Geometry(xyz, Z, lattice=lattice) def set_dynamical_matrix_key(self, key): - """ Overwrites internal key lookup value for the dynamical matrix vectors """ - self.set_key('dyn', key) + """Overwrites internal key lookup value for the dynamical matrix vectors""" + self.set_key("dyn", key) set_dyn_key = set_dynamical_matrix_key def read_dynamical_matrix(self, **kwargs): - """ Returns a GULP dynamical matrix model for the output of GULP + """Returns a GULP dynamical matrix model for the output of GULP Parameters ---------- @@ -188,13 +200,17 @@ def read_dynamical_matrix(self, **kwargs): """ geom = self.read_geometry(**kwargs) - order = kwargs.pop('order', ['got', 'FC']) + order = kwargs.pop("order", ["got", "FC"]) for f in order: - v = getattr(self, '_r_dynamical_matrix_{}'.format(f.lower()))(geom, **kwargs) + v = getattr(self, "_r_dynamical_matrix_{}".format(f.lower()))( + geom, **kwargs + ) if v is not None: # Convert the dynamical matrix such that a diagonalization returns eV ^ 2 - scale = constant.hbar / units('Ang', 'm') / units('eV amu', 'J kg') ** 0.5 - v.data *= scale ** 2 + scale = ( + constant.hbar / units("Ang", "m") / units("eV amu", "J kg") ** 0.5 + ) + v.data *= scale**2 v = DynamicalMatrix.fromsp(geom, v) if kwargs.get("hermitian", True): v = (v + v.transpose()) * 0.5 @@ -204,23 +220,25 @@ def read_dynamical_matrix(self, **kwargs): @sile_fh_open() def _r_dynamical_matrix_got(self, geometry, **kwargs): - """ In case the dynamical matrix is read from the file """ + """In case the dynamical matrix is read from the file""" # Easier for creation of the sparsity pattern from scipy.sparse import lil_matrix # Default cutoff eV / Ang ** 2 - cutoff = kwargs.get('cutoff', 0.) - dtype = kwargs.get('dtype', np.float64) + cutoff = kwargs.get("cutoff", 0.0) + dtype = kwargs.get("dtype", np.float64) nxyz = geometry.no dyn = lil_matrix((nxyz, nxyz), dtype=dtype) - f, _ = self.step_to(self._keys['dyn']) + f, _ = self.step_to(self._keys["dyn"]) if not f: - info(f"{self.__class__.__name__}.read_dynamical_matrix tries to lookup the Dynamical matrix " - "using key '{self._keys['dyn']}'. " - "Use .set_dynamical_matrix_key(...) to search for different name." - "This could not be found found in file: {self.file}") + info( + f"{self.__class__.__name__}.read_dynamical_matrix tries to lookup the Dynamical matrix " + "using key '{self._keys['dyn']}'. " + "Use .set_dynamical_matrix_key(...) to search for different name." + "This could not be found found in file: {self.file}" + ) return None # skip 1 line @@ -243,7 +261,7 @@ def _r_dynamical_matrix_got(self, geometry, **kwargs): # GULP only prints columns corresponding # to a full row. Hence the remaining # data must be nxyz - j - 1 - dat[j:j + k] = ls[:k] + dat[j : j + k] = ls[:k] j += k if j >= nxyz: @@ -262,7 +280,7 @@ def _r_dynamical_matrix_got(self, geometry, **kwargs): # Construct mass ** (-.5), so we can check cutoff correctly (in unit eV/Ang**2) mass_sqrt = geometry.atoms.mass.repeat(3) ** 0.5 dyn.data[:] *= mass_sqrt[dyn.row] * mass_sqrt[dyn.col] - dyn.data[np.fabs(dyn.data) < cutoff] = 0. + dyn.data[np.fabs(dyn.data) < cutoff] = 0.0 dyn.data[:] *= 1 / (mass_sqrt[dyn.row] * mass_sqrt[dyn.col]) dyn.eliminate_zeros() @@ -273,17 +291,21 @@ def _r_dynamical_matrix_got(self, geometry, **kwargs): def _r_dynamical_matrix_fc(self, geometry, **kwargs): # The output of the force constant in the file does not contain the mass-scaling # nor the unit conversion - f = self.dir_file('FORCE_CONSTANTS_2ND') + f = self.dir_file("FORCE_CONSTANTS_2ND") if not f.is_file(): return None - fc = fcSileGULP(f, 'r').read_force_constant(**kwargs) + fc = fcSileGULP(f, "r").read_force_constant(**kwargs) if fc.shape[0] // 3 != geometry.na: - warn(f"{self.__class__.__name__}.read_dynamical_matrix(FC) inconsistent force constant file, na_file={fc.shape[0]//3}, na_geom={geometry.na}") + warn( + f"{self.__class__.__name__}.read_dynamical_matrix(FC) inconsistent force constant file, na_file={fc.shape[0]//3}, na_geom={geometry.na}" + ) return None elif fc.shape[0] != geometry.no: - warn(f"{self.__class__.__name__}.read_dynamical_matrix(FC) inconsistent geometry, no_file={fc.shape[0]}, no_geom={geometry.no}") + warn( + f"{self.__class__.__name__}.read_dynamical_matrix(FC) inconsistent geometry, no_file={fc.shape[0]}, no_geom={geometry.no}" + ) return None # Construct orbital mass ** (-.5) diff --git a/src/sisl/io/gulp/tests/test_gout.py b/src/sisl/io/gulp/tests/test_gout.py index c5040cfad9..31439a2887 100644 --- a/src/sisl/io/gulp/tests/test_gout.py +++ b/src/sisl/io/gulp/tests/test_gout.py @@ -11,13 +11,13 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.gulp] -_dir = osp.join('sisl', 'io', 'gulp') +_dir = osp.join("sisl", "io", "gulp") def test_zz_dynamical_matrix(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'zz.gout')) - D1 = si.read_dynamical_matrix(order=['got'], cutoff=1.e-4) - D2 = si.read_dynamical_matrix(order=['FC'], cutoff=1.e-4) + si = sisl.get_sile(sisl_files(_dir, "zz.gout")) + D1 = si.read_dynamical_matrix(order=["got"], cutoff=1.0e-4) + D2 = si.read_dynamical_matrix(order=["FC"], cutoff=1.0e-4) assert D1._csr.spsame(D2._csr) D1.finalize() @@ -26,7 +26,7 @@ def test_zz_dynamical_matrix(sisl_files): def test_zz_sc_geom(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'zz.gout')) + si = sisl.get_sile(sisl_files(_dir, "zz.gout")) lattice = si.read_lattice() geom = si.read_geometry() assert lattice == geom.lattice @@ -52,7 +52,7 @@ def test_graphene_8x8_untiling(sisl_files): assert np.allclose(d._csr._D, 0, atol=1e-6) # we can't assert that the sparsity patterns are the same # The differences are tiny, 1e-7 but they are there - #assert seg_y.spsame(seg) + # assert seg_y.spsame(seg) seg_x = segs_x.pop() for seg in segs_x: diff --git a/src/sisl/io/ham.py b/src/sisl/io/ham.py index c1528fa949..cec4b301cf 100644 --- a/src/sisl/io/ham.py +++ b/src/sisl/io/ham.py @@ -15,16 +15,16 @@ # Import sile objects from .sile import * -__all__ = ['hamiltonianSile'] +__all__ = ["hamiltonianSile"] @set_module("sisl.io") class hamiltonianSile(Sile): - """ Hamiltonian file object """ + """Hamiltonian file object""" @sile_fh_open() def read_geometry(self): - """ Reading a geometry in regular Hamiltonian format """ + """Reading a geometry in regular Hamiltonian format""" cell = np.zeros([3, 3], np.float64) Z = [] @@ -38,21 +38,21 @@ def Z2no(i, no): return int(i), no except Exception: # both atomic number and no - j = i.replace('[', ' ').replace(']', ' ').split() + j = i.replace("[", " ").replace("]", " ").split() return int(j[0]), int(j[1]) # The format of the geometry file is - keys = ['atoms', 'cell', 'supercell', 'nsc'] + keys = ["atoms", "cell", "supercell", "nsc"] for _ in range(len(keys)): _, l = self.step_to(keys, case=False) l = l.strip() - if 'supercell' in l.lower() or 'nsc' in l.lower(): + if "supercell" in l.lower() or "nsc" in l.lower(): # We have everything in one line l = l.split()[1:] for i in range(3): nsc[i] = int(l[i]) - elif 'cell' in l.lower(): - if 'begin' in l.lower(): + elif "cell" in l.lower(): + if "begin" in l.lower(): for i in range(3): l = self.readline().split() cell[i, 0] = float(l[0]) @@ -65,16 +65,16 @@ def Z2no(i, no): for i in range(3): cell[i, i] = float(l[i]) # TODO incorporate rotations - elif 'atoms' in l.lower(): + elif "atoms" in l.lower(): l = self.readline() - while not l.startswith('end'): + while not l.startswith("end"): ls = l.split() try: no = int(ls[4]) except Exception: no = 1 z, no = Z2no(ls[0], no) - Z.append({'Z': z, 'orbital': [-1. for _ in range(no)]}) + Z.append({"Z": z, "orbital": [-1.0 for _ in range(no)]}) xyz.append([float(f) for f in ls[1:4]]) l = self.readline() xyz = _a.arrayd(xyz) @@ -88,7 +88,7 @@ def Z2no(i, no): @sile_fh_open() def read_hamiltonian(self, hermitian=True, dtype=np.float64, **kwargs): - """ Reads a Hamiltonian (including the geometry) + """Reads a Hamiltonian (including the geometry) Reads the Hamiltonian model """ @@ -110,12 +110,12 @@ def i2o(geom, i): except Exception: # ia[o] # atom ia and the orbital o - j = i.replace('[', ' ').replace(']', ' ').split() + j = i.replace("[", " ").replace("]", " ").split() return geom.a2o(int(j[0])) + int(j[1]) # Start reading in the supercell while True: - found, l = self.step_to('matrix', allow_reread=False) + found, l = self.step_to("matrix", allow_reread=False) if not found: break @@ -129,7 +129,7 @@ def i2o(geom, i): off1 = geom.sc_index(isc) * geom.no off2 = geom.sc_index(-isc) * geom.no l = self.readline() - while not l.startswith('end'): + while not l.startswith("end"): ls = l.split() jo = i2o(geom, ls[0]) io = i2o(geom, ls[1]) @@ -137,7 +137,7 @@ def i2o(geom, i): try: s = float(ls[3]) except Exception: - s = 0. + s = 0.0 H[jo, io + off1] = h S[jo, io + off1] = s if hermitian: @@ -148,7 +148,7 @@ def i2o(geom, i): return Hamiltonian.fromsp(geom, H, S) @sile_fh_open() - def write_geometry(self, geometry, fmt='.8f', **kwargs): + def write_geometry(self, geometry, fmt=".8f", **kwargs): """ Writes the geometry to the output file @@ -162,25 +162,24 @@ def write_geometry(self, geometry, fmt='.8f', **kwargs): # for now, pretty stringent # Get cell_fmt cell_fmt = fmt - if 'cell_fmt' in kwargs: - cell_fmt = kwargs['cell_fmt'] + if "cell_fmt" in kwargs: + cell_fmt = kwargs["cell_fmt"] xyz_fmt = fmt - self._write('begin cell\n') + self._write("begin cell\n") # Write the cell - fmt_str = ' {{0:{0}}} {{1:{0}}} {{2:{0}}}\n'.format(cell_fmt) + fmt_str = " {{0:{0}}} {{1:{0}}} {{2:{0}}}\n".format(cell_fmt) for i in range(3): self._write(fmt_str.format(*geometry.cell[i, :])) - self._write('end cell\n') + self._write("end cell\n") # Write number of super cells in each direction - self._write('\nsupercell {:d} {:d} {:d}\n'.format(*geometry.nsc)) + self._write("\nsupercell {:d} {:d} {:d}\n".format(*geometry.nsc)) # Write all atomic positions along with the specie type - self._write('\nbegin atoms\n') - fmt1_str = ' {{0:d}} {{1:{0}}} {{2:{0}}} {{3:{0}}}\n'.format(xyz_fmt) - fmt2_str = ' {{0:d}}[{{1:d}}] {{2:{0}}} {{3:{0}}} {{4:{0}}}\n'.format( - xyz_fmt) + self._write("\nbegin atoms\n") + fmt1_str = " {{0:d}} {{1:{0}}} {{2:{0}}} {{3:{0}}}\n".format(xyz_fmt) + fmt2_str = " {{0:d}}[{{1:d}}] {{2:{0}}} {{3:{0}}} {{4:{0}}}\n".format(xyz_fmt) for ia in geometry: Z = geometry.atoms[ia].Z @@ -190,12 +189,12 @@ def write_geometry(self, geometry, fmt='.8f', **kwargs): else: self._write(fmt2_str.format(Z, no, *geometry.xyz[ia, :])) - self._write('end atoms\n') + self._write("end atoms\n") @wrap_filterwarnings("ignore", category=SparseEfficiencyWarning) @sile_fh_open() def write_hamiltonian(self, H, hermitian=True, **kwargs): - """ Writes the Hamiltonian model to the file + """Writes the Hamiltonian model to the file Writes a Hamiltonian model to the intrinsic Hamiltonian file format. The file can be constructed by the implict force of Hermiticity, @@ -220,18 +219,19 @@ def write_hamiltonian(self, H, hermitian=True, **kwargs): # We default to the advanced layuot if we have more than one # orbital on any one atom - advanced = kwargs.get('advanced', np.any( - np.array([a.no for a in geom.atoms.atom], np.int32) > 1)) + advanced = kwargs.get( + "advanced", np.any(np.array([a.no for a in geom.atoms.atom], np.int32) > 1) + ) - fmt = kwargs.get('fmt', 'g') + fmt = kwargs.get("fmt", "g") if advanced: - fmt1_str = ' {{0:d}}[{{1:d}}] {{2:d}}[{{3:d}}] {{4:{0}}}\n'.format( - fmt) - fmt2_str = ' {{0:d}}[{{1:d}}] {{2:d}}[{{3:d}}] {{4:{0}}} {{5:{0}}}\n'.format( - fmt) + fmt1_str = " {{0:d}}[{{1:d}}] {{2:d}}[{{3:d}}] {{4:{0}}}\n".format(fmt) + fmt2_str = ( + " {{0:d}}[{{1:d}}] {{2:d}}[{{3:d}}] {{4:{0}}} {{5:{0}}}\n".format(fmt) + ) else: - fmt1_str = f' {{0:d}} {{1:d}} {{2:{fmt}}}\n' - fmt2_str = ' {{0:d}} {{1:d}} {{2:{0}}} {{3:{0}}}\n'.format(fmt) + fmt1_str = f" {{0:d}} {{1:d}} {{2:{fmt}}}\n" + fmt2_str = " {{0:d}} {{1:d}} {{2:{0}}} {{3:{0}}}\n".format(fmt) # We currently force the model to be finalized # before we can write it @@ -243,19 +243,20 @@ def write_hamiltonian(self, H, hermitian=True, **kwargs): # If the model is Hermitian we can # do with writing out half the entries if hermitian: - herm_acc = kwargs.get('herm_acc', 1e-6) + herm_acc = kwargs.get("herm_acc", 1e-6) # We check whether it is Hermitian (not S) for i, isc in enumerate(geom.lattice.sc_off): oi = i * geom.no oj = geom.sc_index(-isc) * geom.no # get the difference between the ^\dagger elements - diff = h[:, oi:oi + geom.no] - \ - h[:, oj:oj + geom.no].transpose() + diff = h[:, oi : oi + geom.no] - h[:, oj : oj + geom.no].transpose() diff.eliminate_zeros() if np.any(np.abs(diff.data) > herm_acc): amax = np.amax(np.abs(diff.data)) - warn(f"The model could not be asserted to be Hermitian " - "within the accuracy required ({amax}).") + warn( + f"The model could not be asserted to be Hermitian " + "within the accuracy required ({amax})." + ) hermitian = False del diff @@ -271,23 +272,23 @@ def write_hamiltonian(self, H, hermitian=True, **kwargs): # Therefore we do it on a row basis, to limit memory # requirements for j in range(geom.no): - h[j, o:o + geom.no] = 0. + h[j, o : o + geom.no] = 0.0 h.eliminate_zeros() if not H.orthogonal: - S[j, o:o + geom.no] = 0. + S[j, o : o + geom.no] = 0.0 S.eliminate_zeros() o = geom.sc_index(np.zeros([3], np.int32)) # Get upper-triangular matrix of the unit-cell H and S - ut = triu(h[:, o:o + geom.no], k=0).tocsr() + ut = triu(h[:, o : o + geom.no], k=0).tocsr() for j in range(geom.no): - h[j, o:o + geom.no] = 0. - h[j, o:o + geom.no] = ut[j, :] + h[j, o : o + geom.no] = 0.0 + h[j, o : o + geom.no] = ut[j, :] h.eliminate_zeros() if not H.orthogonal: - ut = triu(S[:, o:o + geom.no], k=0).tocsr() + ut = triu(S[:, o : o + geom.no], k=0).tocsr() for j in range(geom.no): - S[j, o:o + geom.no] = 0. - S[j, o:o + geom.no] = ut[j, :] + S[j, o : o + geom.no] = 0.0 + S[j, o : o + geom.no] = ut[j, :] S.eliminate_zeros() # Ensure that S and H have the same sparsity pattern @@ -301,13 +302,13 @@ def write_hamiltonian(self, H, hermitian=True, **kwargs): for i, isc in enumerate(geom.lattice.sc_off): # Check that we have any contributions in this # sub-section - Hsub = h[:, i * geom.no:(i + 1) * geom.no] + Hsub = h[:, i * geom.no : (i + 1) * geom.no] if not H.orthogonal: - Ssub = S[:, i * geom.no:(i + 1) * geom.no] + Ssub = S[:, i * geom.no : (i + 1) * geom.no] if Hsub.getnnz() == 0: continue # We have a contribution, write out the information - self._write('\nbegin matrix {:d} {:d} {:d}\n'.format(*isc)) + self._write("\nbegin matrix {:d} {:d} {:d}\n".format(*isc)) if advanced: for jo, io, hh in ispmatrixd(Hsub): o = np.array([jo, io], np.int32) @@ -316,28 +317,26 @@ def write_hamiltonian(self, H, hermitian=True, **kwargs): if not H.orthogonal: s = Ssub[jo, io] elif jo == io: - s = 1. + s = 1.0 else: - s = 0. - if s == 0.: + s = 0.0 + if s == 0.0: self._write(fmt1_str.format(a[0], o[0], a[1], o[1], hh)) else: - self._write( - fmt2_str.format( - a[0], o[0], a[1], o[1], hh, s)) + self._write(fmt2_str.format(a[0], o[0], a[1], o[1], hh, s)) else: for jo, io, hh in ispmatrixd(Hsub): if not H.orthogonal: s = Ssub[jo, io] elif jo == io: - s = 1. + s = 1.0 else: - s = 0. - if s == 0.: + s = 0.0 + if s == 0.0: self._write(fmt1_str.format(jo, io, hh)) else: self._write(fmt2_str.format(jo, io, hh, s)) - self._write('end matrix {:d} {:d} {:d}\n'.format(*isc)) + self._write("end matrix {:d} {:d} {:d}\n".format(*isc)) -add_sile('ham', hamiltonianSile, case=False, gzip=True) +add_sile("ham", hamiltonianSile, case=False, gzip=True) diff --git a/src/sisl/io/molden.py b/src/sisl/io/molden.py index 2149214e27..dfdea26061 100644 --- a/src/sisl/io/molden.py +++ b/src/sisl/io/molden.py @@ -8,28 +8,30 @@ from .sile import * -__all__ = ['moldenSile'] +__all__ = ["moldenSile"] @set_module("sisl.io") class moldenSile(Sile): - """ Molden file object """ + """Molden file object""" @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def write_lattice(self, lattice): - """ Writes the supercell to the contained file """ + """Writes the supercell to the contained file""" # Check that we can write to the file sile_raise_write(self) # Write the number of atoms in the geometry - self._write('[Molden Format]\n') + self._write("[Molden Format]\n") # Sadly, MOLDEN does not read this information... @sile_fh_open() - def write_geometry(self, geometry, fmt='.8f'): - """ Writes the geometry to the contained file """ + def write_geometry(self, geometry, fmt=".8f"): + """Writes the geometry to the contained file""" # Check that we can write to the file sile_raise_write(self) @@ -37,21 +39,23 @@ def write_geometry(self, geometry, fmt='.8f'): self.write_lattice(geometry.lattice) # Write in ATOM mode - self._write('[Atoms] Angs\n') + self._write("[Atoms] Angs\n") # Write out the cell information in the comment field # This contains the cell vectors in a single vector (3 + 3 + 3) # quantities, plus the number of supercells (3 ints) - fmt_str = '{{0:2s}} {{1:4d}} {{2:4d}} {{3:{0}}} {{4:{0}}} {{5:{0}}}\n'.format(fmt) + fmt_str = ( + "{{0:2s}} {{1:4d}} {{2:4d}} {{3:{0}}} {{4:{0}}} {{5:{0}}}\n".format(fmt) + ) for ia, a, _ in geometry.iter_species(): self._write(fmt_str.format(a.symbol, ia, a.Z, *geometry.xyz[ia, :])) def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('molf', moldenSile, case=False, gzip=True) +add_sile("molf", moldenSile, case=False, gzip=True) diff --git a/src/sisl/io/openmx/__init__.py b/src/sisl/io/openmx/__init__.py index d1a7130a04..7bddb9cb19 100644 --- a/src/sisl/io/openmx/__init__.py +++ b/src/sisl/io/openmx/__init__.py @@ -8,6 +8,6 @@ omxSileOpenMX - input file mdSileOpenMX """ -from .sile import * # isort: split +from .sile import * # isort: split from .md import * from .omx import * diff --git a/src/sisl/io/openmx/md.py b/src/sisl/io/openmx/md.py index b3898c4c7d..a871562bf0 100644 --- a/src/sisl/io/openmx/md.py +++ b/src/sisl/io/openmx/md.py @@ -18,4 +18,4 @@ class mdSileOpenMX(xyzSile, SileOpenMX): pass -add_sile('md', mdSileOpenMX, gzip=True) +add_sile("md", mdSileOpenMX, gzip=True) diff --git a/src/sisl/io/openmx/omx.py b/src/sisl/io/openmx/omx.py index 0b47f84544..a9957ff78a 100644 --- a/src/sisl/io/openmx/omx.py +++ b/src/sisl/io/openmx/omx.py @@ -12,15 +12,15 @@ from ..sile import * from .sile import SileOpenMX -__all__ = ['omxSileOpenMX'] +__all__ = ["omxSileOpenMX"] -_LOGICAL_TRUE = ['on', 'yes', 'true', '.true.', 'ok'] -_LOGICAL_FALSE = ['off', 'no', 'false', '.false.', 'ng'] +_LOGICAL_TRUE = ["on", "yes", "true", ".true.", "ok"] +_LOGICAL_FALSE = ["off", "no", "false", ".false.", "ng"] _LOGICAL = _LOGICAL_FALSE + _LOGICAL_TRUE class omxSileOpenMX(SileOpenMX): - r""" OpenMX-input file + r"""OpenMX-input file By supplying base you can reference files in other directories. By default the ``base`` is the directory given in the file name. @@ -51,14 +51,14 @@ class omxSileOpenMX(SileOpenMX): @property def file(self): - """ Return the current file name (without the directory prefix) """ + """Return the current file name (without the directory prefix)""" return self._file def _setup(self, *args, **kwargs): - """ Setup the `omxSileOpenMX` after initialization """ + """Setup the `omxSileOpenMX` after initialization""" super()._setup(*args, **kwargs) # These are the comments - self._comment = ['#'] + self._comment = ["#"] # List of parent file-handles used while reading self._parent_fh = [] @@ -68,7 +68,9 @@ def _pushfile(self, f): self._parent_fh.append(self.fh) self.fh = self.dir_file(f).open(self._mode) else: - warn(f'{self!s} is trying to include file: {f} but the file seems not to exist? Will disregard file!') + warn( + f"{self!s} is trying to include file: {f} but the file seems not to exist? Will disregard file!" + ) def _popfile(self): if len(self._parent_fh) > 0: @@ -78,7 +80,7 @@ def _popfile(self): return False def _seek(self): - """ Closes all files, and starts over from beginning """ + """Closes all files, and starts over from beginning""" try: while self._popfile(): pass @@ -88,7 +90,7 @@ def _seek(self): @sile_fh_open() def _read_key(self, key): - """ Try and read the first occurence of a key + """Try and read the first occurence of a key This will take care of blocks, labels and piped in labels @@ -98,8 +100,10 @@ def _read_key(self, key): key to find in the file """ self._seek() + def tokey(key): return key.lower() + keyl = tokey(key) def valid_line(line): @@ -120,12 +124,12 @@ def process_line(line): # The last case is if the key is the first word on the line # In that case we have found what we are looking for if lsl[0] == keyl: - return (' '.join(ls[1:])).strip() + return (" ".join(ls[1:])).strip() - elif lsl[0].startswith('<'): + elif lsl[0].startswith("<"): # Get key lsl_key = lsl[0][1:] - lsl_end = lsl_key + '>' + lsl_end = lsl_key + ">" if lsl_key == keyl: # Read in the block content lines = [] @@ -141,12 +145,12 @@ def process_line(line): return None # Perform actual reading of line - l = self.readline().split('#')[0] + l = self.readline().split("#")[0] if len(l) == 0: return None l = process_line(l) while l is None: - l = self.readline().split('#')[0] + l = self.readline().split("#")[0] if len(l) == 0: if not self._popfile(): return None @@ -156,7 +160,7 @@ def process_line(line): @classmethod def _type(cls, value): - """ Determine the type by the value + """Determine the type by the value Parameters ---------- @@ -168,7 +172,7 @@ def _type(cls, value): if isinstance(value, list): # A block, <[name] - return 'B' + return "B" # Grab the entire line (beside the key) values = value.split() @@ -176,23 +180,23 @@ def _type(cls, value): val = values[0].lower() if val in _LOGICAL: # logical - return 'l' + return "l" try: float(val) - if '.' in val: + if "." in val: # a real number (otherwise an integer) - return 'r' - return 'i' + return "r" + return "i" except Exception: pass # fall-back to name with everything - return 'n' + return "n" @sile_fh_open() def type(self, label): - """ Return the type of the fdf-keyword + """Return the type of the fdf-keyword Parameters ---------- @@ -204,7 +208,7 @@ def type(self, label): @sile_fh_open() def get(self, key, default=None): - """ Retrieve keyword from the file + """Retrieve keyword from the file Parameters ---------- @@ -231,25 +235,25 @@ def get(self, key, default=None): # We will only do something if it is a real, int, or physical. # Else we simply return, as-is - if t == 'r': + if t == "r": if default is None: return float(value) t = type(default) return t(value) - elif t == 'i': + elif t == "i": if default is None: return int(value) t = type(default) return t(value) - elif t == 'l': + elif t == "l": return value.lower() in _LOGICAL_TRUE return value def read_basis(self, *args, **kwargs): - """ Reads basis + """Reads basis Parameters ---------- @@ -260,16 +264,16 @@ def read_basis(self, *args, **kwargs): the order of which to try and read the lattice If `order` is present `output` is disregarded. """ - order = kwargs.pop('order', ['dat', 'omx']) + order = kwargs.pop("order", ["dat", "omx"]) for f in order: - v = getattr(self, '_r_basis_{}'.format(f.lower()))(*args, **kwargs) + v = getattr(self, "_r_basis_{}".format(f.lower()))(*args, **kwargs) if v is not None: return v return None def _r_basis_omx(self): - ns = self.get('Species.Number', 0) - data = self.get('Definition.of.Atomic.Species') + ns = self.get("Species.Number", 0) + data = self.get("Definition.of.Atomic.Species") if data is None: return None @@ -283,11 +287,11 @@ def rf_func(R): f = np.ones(500) f[r > R] = 0 return r, f - return np.linspace(0, 1., 10), np.zeros(10) + return np.linspace(0, 1.0, 10), np.zeros(10) def decompose_basis(l): # Only split once - Zr, spec = l.split('-', 1) + Zr, spec = l.split("-", 1) idx = 0 for i, c in enumerate(Zr): if c.isdigit(): @@ -307,21 +311,33 @@ def decompose_basis(l): orbs = [] m_order = { 0: [0], - 1: [1, -1, 0], # px, py, pz - 2: [0, 2, -2, 1, -1], # d3z^2-r^2, dx^2-y^2, dxy, dxz, dyz - 3: [0, 1, -1, 2, -2, 3, -3], # f5z^2-3r^2, f5xz^2-xr^2, f5yz^2-yr^2, fzx^2-zy^2, fxyz, fx^3-3*xy^2, f3yx^2-y^3 + 1: [1, -1, 0], # px, py, pz + 2: [0, 2, -2, 1, -1], # d3z^2-r^2, dx^2-y^2, dxy, dxz, dyz + 3: [ + 0, + 1, + -1, + 2, + -2, + 3, + -3, + ], # f5z^2-3r^2, f5xz^2-xr^2, f5yz^2-yr^2, fzx^2-zy^2, fxyz, fx^3-3*xy^2, f3yx^2-y^3 4: [0, 1, -1, 2, -2, 3, -3, 4, -4], - 5: [0, 1, -1, 2, -2, 3, -3, 4, -4, 5, -5] + 5: [0, 1, -1, 2, -2, 3, -3, 4, -4, 5, -5], } for i, c in enumerate(spec): try: l = "spdfgh".index(c) try: - nZ = int(spec[i+1]) + nZ = int(spec[i + 1]) except Exception: nZ = 1 for z in range(nZ): - orbs.extend(SphericalOrbital(l, rf_func(R)).toAtomicOrbital(m=m_order[l], zeta=z+1)) + orbs.extend( + SphericalOrbital(l, rf_func(R)).toAtomicOrbital( + m=m_order[l], zeta=z + 1 + ) + ) except Exception: pass @@ -337,7 +353,7 @@ def decompose_basis(l): return atom def read_lattice(self, output=False, *args, **kwargs): - """ Reads lattice + """Reads lattice One can limit the tried files to only one file by passing only a single file ending. @@ -352,32 +368,32 @@ def read_lattice(self, output=False, *args, **kwargs): If `order` is present `output` is disregarded. """ if output: - order = kwargs.pop('order', ['dat', 'omx']) + order = kwargs.pop("order", ["dat", "omx"]) else: - order = kwargs.pop('order', ['dat', 'omx']) + order = kwargs.pop("order", ["dat", "omx"]) for f in order: - v = getattr(self, '_r_lattice_{}'.format(f.lower()))(*args, **kwargs) + v = getattr(self, "_r_lattice_{}".format(f.lower()))(*args, **kwargs) if v is not None: return v return None def _r_lattice_omx(self, *args, **kwargs): - """ Returns `Lattice` object from the omx file """ - conv = self.get('Atoms.UnitVectors.Unit', default='Ang') - if conv.upper() == 'AU': - conv = units('Bohr', 'Ang') + """Returns `Lattice` object from the omx file""" + conv = self.get("Atoms.UnitVectors.Unit", default="Ang") + if conv.upper() == "AU": + conv = units("Bohr", "Ang") else: - conv = 1. + conv = 1.0 # Read in cell cell = np.empty([3, 3], np.float64) - lc = self.get('Atoms.UnitVectors') + lc = self.get("Atoms.UnitVectors") if not lc is None: for i in range(3): cell[i, :] = [float(k) for k in lc[i].split()[:3]] else: - raise SileError('Could not find Atoms.UnitVectors in file') + raise SileError("Could not find Atoms.UnitVectors in file") cell *= conv return Lattice(cell) @@ -385,7 +401,7 @@ def _r_lattice_omx(self, *args, **kwargs): _r_lattice_dat = _r_lattice_omx def read_geometry(self, output=False, *args, **kwargs): - """ Returns Geometry object + """Returns Geometry object One can limit the tried files to only one file by passing only a single file ending. @@ -400,24 +416,24 @@ def read_geometry(self, output=False, *args, **kwargs): If `order` is present `output` is disregarded. """ if output: - order = kwargs.pop('order', ['dat', 'omx']) + order = kwargs.pop("order", ["dat", "omx"]) else: - order = kwargs.pop('order', ['dat', 'omx']) + order = kwargs.pop("order", ["dat", "omx"]) for f in order: - v = getattr(self, '_r_geometry_{}'.format(f.lower()))(*args, **kwargs) + v = getattr(self, "_r_geometry_{}".format(f.lower()))(*args, **kwargs) if v is not None: return v return None def _r_geometry_omx(self, *args, **kwargs): - """ Returns `Geometry` """ - lattice = self.read_lattice(order=['omx']) + """Returns `Geometry`""" + lattice = self.read_lattice(order=["omx"]) - na = self.get('Atoms.Number', default=0) - conv = self.get('Atoms.SpeciesAndCoordinates.Unit', default='Ang') - data = self.get('Atoms.SpeciesAndCoordinates') + na = self.get("Atoms.Number", default=0) + conv = self.get("Atoms.SpeciesAndCoordinates.Unit", default="Ang") + data = self.get("Atoms.SpeciesAndCoordinates") if data is None: - raise SislError('Cannot find key: Atoms.SpeciesAndCoordinates') + raise SislError("Cannot find key: Atoms.SpeciesAndCoordinates") if na == 0: # Default to the size of the labels @@ -426,14 +442,15 @@ def _r_geometry_omx(self, *args, **kwargs): # Reduce to the number of atoms. data = data[:na] - atoms = self.read_basis(order=['omx']) + atoms = self.read_basis(order=["omx"]) + def find_atom(tag): if atoms is None: return Atom(tag) for atom in atoms: if atom.tag == tag: return atom - raise SislError(f'Error when reading the basis for atomic tag: {tag}.') + raise SislError(f"Error when reading the basis for atomic tag: {tag}.") xyz = [] atom = [] @@ -443,9 +460,9 @@ def find_atom(tag): xyz.append(list(map(float, dat.split()[2:5]))) xyz = _a.arrayd(xyz) - if conv == 'AU': - xyz *= units('Bohr', 'Ang') - elif conv == 'FRAC': + if conv == "AU": + xyz *= units("Bohr", "Ang") + elif conv == "FRAC": xyz = np.dot(xyz, lattice.cell) return Geometry(xyz, atoms=atom, lattice=lattice) @@ -453,4 +470,4 @@ def find_atom(tag): _r_geometry_dat = _r_geometry_omx -add_sile('omx', omxSileOpenMX, case=False, gzip=True) +add_sile("omx", omxSileOpenMX, case=False, gzip=True) diff --git a/src/sisl/io/openmx/sile.py b/src/sisl/io/openmx/sile.py index 6f1bbff65c..7eb777f707 100644 --- a/src/sisl/io/openmx/sile.py +++ b/src/sisl/io/openmx/sile.py @@ -5,7 +5,7 @@ from ..sile import Sile, SileBin, SileCDF -__all__ = ['SileOpenMX', 'SileCDFOpenMX', 'SileBinOpenMX'] +__all__ = ["SileOpenMX", "SileCDFOpenMX", "SileBinOpenMX"] @set_module("sisl.io.openmx") diff --git a/src/sisl/io/orca/__init__.py b/src/sisl/io/orca/__init__.py index 6d9f837f36..d5e4d5730d 100644 --- a/src/sisl/io/orca/__init__.py +++ b/src/sisl/io/orca/__init__.py @@ -8,6 +8,6 @@ stdoutSileORCA - ORCA output file txtSileORCA - ORCA property.txt file """ -from .sile import * # isort: split +from .sile import * # isort: split from .stdout import * from .txt import * diff --git a/src/sisl/io/orca/sile.py b/src/sisl/io/orca/sile.py index 06d15c42bd..d7a37e03fb 100644 --- a/src/sisl/io/orca/sile.py +++ b/src/sisl/io/orca/sile.py @@ -5,7 +5,7 @@ from ..sile import Sile, SileBin -__all__ = ['SileORCA', 'SileBinORCA'] +__all__ = ["SileORCA", "SileBinORCA"] @set_module("sisl.io.orca") diff --git a/src/sisl/io/orca/stdout.py b/src/sisl/io/orca/stdout.py index 7c3e27e944..1ab40859f6 100644 --- a/src/sisl/io/orca/stdout.py +++ b/src/sisl/io/orca/stdout.py @@ -20,39 +20,57 @@ @set_module("sisl.io.orca") class stdoutSileORCA(SileORCA): - """ Output file from ORCA """ + """Output file from ORCA""" _info_attributes_ = [ - _A("na", r".*Number of atoms", - lambda attr, match: int(match.string.split()[-1])), - _A("no", r".*Number of basis functions", - lambda attr, match: int(match.string.split()[-1])), - _A("vdw_correction", r".*DFT DISPERSION CORRECTION", - lambda attr, match: True, default=False), - _A("completed", r".*ORCA TERMINATED NORMALLY", - lambda attr, match: True, default=False), + _A( + "na", + r".*Number of atoms", + lambda attr, match: int(match.string.split()[-1]), + ), + _A( + "no", + r".*Number of basis functions", + lambda attr, match: int(match.string.split()[-1]), + ), + _A( + "vdw_correction", + r".*DFT DISPERSION CORRECTION", + lambda attr, match: True, + default=False, + ), + _A( + "completed", + r".*ORCA TERMINATED NORMALLY", + lambda attr, match: True, + default=False, + ), ] def completed(self): - """ True if the full file has been read and "ORCA TERMINATED NORMALLY" was found. """ + """True if the full file has been read and "ORCA TERMINATED NORMALLY" was found.""" return self.info.completed @property - @deprecation("stdoutSileORCA.na is deprecated in favor of stdoutSileORCA.info.na", "0.16.0") + @deprecation( + "stdoutSileORCA.na is deprecated in favor of stdoutSileORCA.info.na", "0.16.0" + ) def na(self): - """ Number of atoms """ + """Number of atoms""" return self.info.na @property - @deprecation("stdoutSileORCA.no is deprecated in favor of stdoutSileORCA.info.no", "0.16.0") + @deprecation( + "stdoutSileORCA.no is deprecated in favor of stdoutSileORCA.info.no", "0.16.0" + ) def no(self): - """ Number of orbitals (basis functions) """ + """Number of orbitals (basis functions)""" return self.info.no @SileBinder(postprocess=np.array) @sile_fh_open() def read_electrons(self): - """ Read number of electrons (alpha, beta) + """Read number of electrons (alpha, beta) Returns ------- @@ -68,10 +86,15 @@ def read_electrons(self): @SileBinder() @sile_fh_open() - def read_charge(self, name='mulliken', projection='orbital', - orbitals=None, - reduced=True, spin=False): - """ Reads from charge (or spin) population analysis + def read_charge( + self, + name="mulliken", + projection="orbital", + orbitals=None, + reduced=True, + spin=False, + ): + """Reads from charge (or spin) population analysis Parameters ---------- @@ -90,24 +113,24 @@ def read_charge(self, name='mulliken', projection='orbital', ------- PropertyDicts or ndarray or list thereof: atom/orbital-resolved charge (or spin) data """ - if name.lower() in ('mulliken', 'm'): - name = 'mulliken' - elif name.lower() in ('loewdin', 'lowdin', 'löwdin', 'l'): - name = 'loewdin' + if name.lower() in ("mulliken", "m"): + name = "mulliken" + elif name.lower() in ("loewdin", "lowdin", "löwdin", "l"): + name = "loewdin" else: raise NotImplementedError(f"name={name} is not implemented") - if projection.lower() in ('atom', 'atoms', 'a'): - projection = 'atom' - elif projection.lower() in ('orbital', 'orbitals', 'orb', 'o'): - projection = 'orbital' + if projection.lower() in ("atom", "atoms", "a"): + projection = "atom" + elif projection.lower() in ("orbital", "orbitals", "orb", "o"): + projection = "orbital" else: raise ValueError(f"Projection must be atom or orbital") - if projection == 'atom': - if name == 'mulliken': + if projection == "atom": + if name == "mulliken": step_to = "MULLIKEN ATOMIC CHARGES" - elif name == 'loewdin': + elif name == "loewdin": step_to = "LOEWDIN ATOMIC CHARGES" def read_block(step_to): @@ -119,7 +142,7 @@ def read_block(step_to): else: spin_block = False - self.readline() # skip --- + self.readline() # skip --- A = np.empty(self.info.na, np.float64) for ia in range(self.info.na): line = self.readline() @@ -132,10 +155,10 @@ def read_block(step_to): A[ia] = v[-1] return A - elif projection == 'orbital' and reduced: - if name == 'mulliken': + elif projection == "orbital" and reduced: + if name == "mulliken": step_to = "MULLIKEN REDUCED ORBITAL CHARGES" - elif name == 'loewdin': + elif name == "loewdin": step_to = "LOEWDIN REDUCED ORBITAL CHARGES" def read_reduced_orbital_block(): @@ -169,7 +192,7 @@ def read_block(step_to): elif spin_block: self.step_to("CHARGE", allow_reread=False) elif not spin: - self.readline() # skip --- + self.readline() # skip --- else: return None @@ -183,10 +206,10 @@ def read_block(step_to): Da[ia] = d return Da - elif projection == 'orbital' and not reduced: - if name == 'mulliken': + elif projection == "orbital" and not reduced: + if name == "mulliken": step_to = "MULLIKEN ORBITAL CHARGES" - elif name == 'loewdin': + elif name == "loewdin": step_to = "LOEWDIN ORBITAL CHARGES" def read_block(step_to): @@ -199,17 +222,17 @@ def read_block(step_to): if not f: return None - self.readline() # skip --- + self.readline() # skip --- if "MULLIKEN" in step_to: - self.readline() # skip line "The uncorrected..." + self.readline() # skip line "The uncorrected..." - Do = np.empty(self.info.no, np.float64) # orbital-resolved - Da = np.zeros(self.info.na, np.float64) # atom-resolved + Do = np.empty(self.info.no, np.float64) # orbital-resolved + Da = np.zeros(self.info.na, np.float64) # atom-resolved for io in range(self.info.no): - v = self.readline().split() # io, ia+element, orb, chg, (spin) + v = self.readline().split() # io, ia+element, orb, chg, (spin) # split atom number and element from v[1] - ia, element = '', '' + ia, element = "", "" for s in v[1]: if s.isdigit(): ia += s @@ -234,7 +257,7 @@ def read_block(step_to): @SileBinder() @sile_fh_open() def read_energy(self): - """ Reads the energy blocks + """Reads the energy blocks Returns ------- @@ -244,10 +267,10 @@ def read_energy(self): if not f: return None - self.readline() # skip --- - self.readline() # skip blank line + self.readline() # skip --- + self.readline() # skip blank line - Ha2eV = units('Ha', 'eV') + Ha2eV = units("Ha", "eV") E = PropertyDict() line = self.readline() @@ -275,7 +298,7 @@ def read_energy(self): @SileBinder() @sile_fh_open() def read_orbital_energies(self): - """ Reads the "ORBITAL ENERGIES" blocks + """Reads the "ORBITAL ENERGIES" blocks Returns ------- @@ -285,7 +308,7 @@ def read_orbital_energies(self): if not f: return None - self.readline() # skip --- + self.readline() # skip --- if "SPIN UP ORBITALS" in self.readline(): spin = True E = np.empty([self.info.no, 2], np.float64) @@ -293,7 +316,7 @@ def read_orbital_energies(self): spin = False E = np.empty([self.info.no, 1], np.float64) - self.readline() # Skip "NO OCC" header line + self.readline() # Skip "NO OCC" header line v = self.readline().split() while len(v) > 0: @@ -304,17 +327,19 @@ def read_orbital_energies(self): if not spin: return E.ravel() - self.readline() # skip "SPIN DOWN ORBITALS" - self.readline() # Skip "NO OCC" header line + self.readline() # skip "SPIN DOWN ORBITALS" + self.readline() # Skip "NO OCC" header line v = self.readline().split() - while len(v) > 0 and '---' not in v[0]: + while len(v) > 0 and "---" not in v[0]: i = int(v[0]) E[i, 1] = v[-1] v = self.readline().split() return E -outputSileORCA = deprecation("outputSileORCA has been deprecated in favor of stdoutSileOrca.", "0.15")(stdoutSileORCA) +outputSileORCA = deprecation( + "outputSileORCA has been deprecated in favor of stdoutSileOrca.", "0.15" +)(stdoutSileORCA) add_sile("output", stdoutSileORCA, gzip=True, case=False) add_sile("orca.out", stdoutSileORCA, gzip=True, case=False) diff --git a/src/sisl/io/orca/tests/test_stdout.py b/src/sisl/io/orca/tests/test_stdout.py index d40230744d..7b226a063f 100644 --- a/src/sisl/io/orca/tests/test_stdout.py +++ b/src/sisl/io/orca/tests/test_stdout.py @@ -11,18 +11,19 @@ from sisl.io.orca.stdout import * pytestmark = [pytest.mark.io, pytest.mark.orca] -_dir = osp.join('sisl', 'io', 'orca') +_dir = osp.join("sisl", "io", "orca") def test_tags(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) assert out.info.na == 2 assert out.info.no == 62 assert out.completed() + def test_read_electrons(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) N = out.read_electrons[:]() @@ -33,18 +34,20 @@ def test_read_electrons(sisl_files): assert N[0] == 7.999998537734 assert N[1] == 6.999998987209 + def test_charge_name(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) - for name in ['mulliken', 'MULLIKEN', 'loewdin', 'Lowdin', 'LÖWDIN']: + for name in ["mulliken", "MULLIKEN", "loewdin", "Lowdin", "LÖWDIN"]: assert out.read_charge(name=name) is not None + def test_charge_mulliken_atom(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='mulliken', projection='atom') - S = out.read_charge[:](name='mulliken', projection='atom', spin=True) + C = out.read_charge[:](name="mulliken", projection="atom") + S = out.read_charge[:](name="mulliken", projection="atom", spin=True) assert len(C) == 2 assert C[0][0] == 0.029160 assert S[0][0] == 0.687779 @@ -55,19 +58,20 @@ def test_charge_mulliken_atom(sisl_files): assert C[1][1] == -0.029158 assert S[1][1] == 0.312207 - C = out.read_charge[-1](name='mulliken', projection='atom') - S = out.read_charge[-1](name='mulliken', projection='atom', spin=True) + C = out.read_charge[-1](name="mulliken", projection="atom") + S = out.read_charge[-1](name="mulliken", projection="atom", spin=True) assert C[0] == 0.029158 assert S[0] == 0.687793 assert C[1] == -0.029158 assert S[1] == 0.312207 + def test_lowedin_atom(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='loewdin', projection='atom') - S = out.read_charge[:](name='loewdin', projection='atom', spin=True) + C = out.read_charge[:](name="loewdin", projection="atom") + S = out.read_charge[:](name="loewdin", projection="atom", spin=True) assert len(C) == 2 assert C[0][0] == -0.111221 assert S[0][0] == 0.660316 @@ -78,89 +82,98 @@ def test_lowedin_atom(sisl_files): assert C[1][1] == 0.111223 assert S[1][1] == 0.339673 - C = out.read_charge[-1](name='loewdin', projection='atom') - S = out.read_charge[-1](name='loewdin', projection='atom', spin=True) + C = out.read_charge[-1](name="loewdin", projection="atom") + S = out.read_charge[-1](name="loewdin", projection="atom", spin=True) assert C[0] == -0.111223 assert S[0] == 0.660327 assert C[1] == 0.111223 assert S[1] == 0.339673 + def test_charge_mulliken_reduced(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='mulliken', projection='orbital') - S = out.read_charge[:](name='mulliken', projection='orbital', spin=True) + C = out.read_charge[:](name="mulliken", projection="orbital") + S = out.read_charge[:](name="mulliken", projection="orbital", spin=True) assert len(C) == 2 # first charge block - assert C[0][(0, 's')] == 3.915850 - assert C[0][(0, 'pz')] == 0.710261 - assert C[0][(1, 'dz2')] == 0.004147 - assert C[0][(1, 'p')] == 4.116068 + assert C[0][(0, "s")] == 3.915850 + assert C[0][(0, "pz")] == 0.710261 + assert C[0][(1, "dz2")] == 0.004147 + assert C[0][(1, "p")] == 4.116068 # first spin block - assert S[0][(0, 'dx2y2')] == 0.001163 - assert S[0][(1, 'f+2')] == -0.000122 + assert S[0][(0, "dx2y2")] == 0.001163 + assert S[0][(1, "f+2")] == -0.000122 # last charge block - assert C[1][(0, 'pz')] == 0.710263 - assert C[1][(0, 'f0')] == 0.000681 - assert C[1][(1, 's')] == 3.860487 + assert C[1][(0, "pz")] == 0.710263 + assert C[1][(0, "f0")] == 0.000681 + assert C[1][(1, "s")] == 3.860487 # last spin block - assert S[1][(0, 'p')] == 0.685743 - assert S[1][(1, 'dz2')] == -0.000163 + assert S[1][(0, "p")] == 0.685743 + assert S[1][(1, "dz2")] == -0.000163 - C = out.read_charge[-1](name='mulliken', projection='orbital') - S = out.read_charge[-1](name='mulliken', projection='orbital', spin=True) + C = out.read_charge[-1](name="mulliken", projection="orbital") + S = out.read_charge[-1](name="mulliken", projection="orbital", spin=True) # last charge block - assert C[(0, 'pz')] == 0.710263 - assert C[(0, 'f0')] == 0.000681 - assert C[(1, 's')] == 3.860487 + assert C[(0, "pz")] == 0.710263 + assert C[(0, "f0")] == 0.000681 + assert C[(1, "s")] == 3.860487 # last spin block - assert S[(0, 'p')] == 0.685743 - assert S[(1, 'dz2')] == -0.000163 + assert S[(0, "p")] == 0.685743 + assert S[(1, "dz2")] == -0.000163 - C = out.read_charge[:](name='mulliken', projection='orbital', orbitals='pz') + C = out.read_charge[:](name="mulliken", projection="orbital", orbitals="pz") assert C[0][0] == 0.710261 - S = out.read_charge[:](name='mulliken', projection='orbital', orbitals='f+2', spin=True) + S = out.read_charge[:]( + name="mulliken", projection="orbital", orbitals="f+2", spin=True + ) assert S[0][1] == -0.000122 - S = out.read_charge[-1](name='mulliken', projection='orbital', orbitals='p', spin=True) + S = out.read_charge[-1]( + name="mulliken", projection="orbital", orbitals="p", spin=True + ) assert S[0] == 0.685743 + def test_charge_loewdin_reduced(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='loewdin', projection='orbital') - S = out.read_charge[:](name='loewdin', projection='orbital', spin=True) + C = out.read_charge[:](name="loewdin", projection="orbital") + S = out.read_charge[:](name="loewdin", projection="orbital", spin=True) assert len(S) == 2 - assert C[0][(0, 's')] == 3.553405 - assert C[0][(0, 'pz')] == 0.723111 - assert C[1][(0, 'pz')] == 0.723113 - assert S[1][(1, 'pz')] == -0.010829 - - C = out.read_charge[-1](name='loewdin', projection='orbital') - S = out.read_charge[-1](name='loewdin', projection='orbital', spin=True) - assert C[(0, 'f-3')] == 0.017486 - assert S[(1, 'pz')] == -0.010829 - assert C[(0, 'pz')] == 0.723113 - assert S[(1, 'pz')] == -0.010829 - - C = out.read_charge[:](name='loewdin', projection='orbital', orbitals='s') + assert C[0][(0, "s")] == 3.553405 + assert C[0][(0, "pz")] == 0.723111 + assert C[1][(0, "pz")] == 0.723113 + assert S[1][(1, "pz")] == -0.010829 + + C = out.read_charge[-1](name="loewdin", projection="orbital") + S = out.read_charge[-1](name="loewdin", projection="orbital", spin=True) + assert C[(0, "f-3")] == 0.017486 + assert S[(1, "pz")] == -0.010829 + assert C[(0, "pz")] == 0.723113 + assert S[(1, "pz")] == -0.010829 + + C = out.read_charge[:](name="loewdin", projection="orbital", orbitals="s") assert C[0][0] == 3.553405 - C = out.read_charge[-1](name='loewdin', projection='orbital', orbitals='f-3') + C = out.read_charge[-1](name="loewdin", projection="orbital", orbitals="f-3") assert C[0] == 0.017486 - C = out.read_charge[-1](name='loewdin', projection='orbital', orbitals='pz') + C = out.read_charge[-1](name="loewdin", projection="orbital", orbitals="pz") assert C[0] == 0.723113 + def test_charge_mulliken_full(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='mulliken', projection='orbital', reduced=False) - S = out.read_charge[:](name='mulliken', projection='orbital', reduced=False, spin=True) + C = out.read_charge[:](name="mulliken", projection="orbital", reduced=False) + S = out.read_charge[:]( + name="mulliken", projection="orbital", reduced=False, spin=True + ) assert len(C) == 2 assert C[0][0] == 0.821857 assert S[0][0] == -0.000020 @@ -169,17 +182,22 @@ def test_charge_mulliken_full(sisl_files): assert C[1][8] == 0.313072 assert S[1][8] == 0.006429 - C = out.read_charge[-1](name='mulliken', projection='orbital', reduced=False) - S = out.read_charge[-1](name='mulliken', projection='orbital', reduced=False, spin=True) + C = out.read_charge[-1](name="mulliken", projection="orbital", reduced=False) + S = out.read_charge[-1]( + name="mulliken", projection="orbital", reduced=False, spin=True + ) assert C[8] == 0.313072 assert S[8] == 0.006429 + def test_charge_loewdin_full(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='loewdin', projection='orbital', reduced=False) - S = out.read_charge[:](name='loewdin', projection='orbital', reduced=False, spin=True) + C = out.read_charge[:](name="loewdin", projection="orbital", reduced=False) + S = out.read_charge[:]( + name="loewdin", projection="orbital", reduced=False, spin=True + ) assert len(S) == 2 assert C[0][0] == 0.894846 assert S[0][0] == 0.000337 @@ -188,75 +206,87 @@ def test_charge_loewdin_full(sisl_files): assert C[1][8] == 0.312172 assert S[1][8] == 0.005159 - C = out.read_charge[-1](name='loewdin', projection='orbital', reduced=False) - S = out.read_charge[-1](name='loewdin', projection='orbital', reduced=False, spin=True) + C = out.read_charge[-1](name="loewdin", projection="orbital", reduced=False) + S = out.read_charge[-1]( + name="loewdin", projection="orbital", reduced=False, spin=True + ) assert C[8] == 0.312172 assert S[8] == 0.005159 + def test_charge_atom_unpol(sisl_files): - f = sisl_files(_dir, 'molecule2.output') + f = sisl_files(_dir, "molecule2.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='mulliken', projection='atom') - S = out.read_charge[:](name='mulliken', projection='atom', spin=True) + C = out.read_charge[:](name="mulliken", projection="atom") + S = out.read_charge[:](name="mulliken", projection="atom", spin=True) assert len(C) == 2 assert S is None assert C[0][0] == -0.037652 - C = out.read_charge[-1](name='mulliken', projection='atom') - S = out.read_charge[-1](name='mulliken', projection='atom', spin=True) + C = out.read_charge[-1](name="mulliken", projection="atom") + S = out.read_charge[-1](name="mulliken", projection="atom", spin=True) assert C[0] == -0.037652 assert S is None - C = out.read_charge[-1](name='loewdin', projection='atom') - S = out.read_charge[-1](name='loewdin', projection='atom', spin=True) + C = out.read_charge[-1](name="loewdin", projection="atom") + S = out.read_charge[-1](name="loewdin", projection="atom", spin=True) assert C[0] == -0.259865 assert S is None + def test_charge_orbital_reduced_unpol(sisl_files): - f = sisl_files(_dir, 'molecule2.output') + f = sisl_files(_dir, "molecule2.output") out = stdoutSileORCA(f) - C = out.read_charge[:](name='mulliken', projection='orbital') - S = out.read_charge[:](name='mulliken', projection='orbital', spin=True) + C = out.read_charge[:](name="mulliken", projection="orbital") + S = out.read_charge[:](name="mulliken", projection="orbital", spin=True) assert len(C) == 2 assert S is None assert C[0][(0, "py")] == 0.534313 assert C[1][(1, "px")] == 1.346363 - C = out.read_charge[-1](name='mulliken', projection='orbital') - S = out.read_charge[-1](name='mulliken', projection='orbital', spin=True) + C = out.read_charge[-1](name="mulliken", projection="orbital") + S = out.read_charge[-1](name="mulliken", projection="orbital", spin=True) assert C[(0, "px")] == 0.954436 assert S is None - C = out.read_charge[-1](name='mulliken', projection='orbital', orbitals='px') - S = out.read_charge[-1](name='mulliken', projection='orbital', orbitals='px', spin=True) + C = out.read_charge[-1](name="mulliken", projection="orbital", orbitals="px") + S = out.read_charge[-1]( + name="mulliken", projection="orbital", orbitals="px", spin=True + ) assert C[0] == 0.954436 assert S is None - C = out.read_charge[-1](name='loewdin', projection='orbital') - S = out.read_charge[-1](name='loewdin', projection='orbital', spin=True) + C = out.read_charge[-1](name="loewdin", projection="orbital") + S = out.read_charge[-1](name="loewdin", projection="orbital", spin=True) assert C[(0, "d")] == 0.315910 assert S is None - C = out.read_charge[-1](name='loewdin', projection='orbital', orbitals='d') - S = out.read_charge[-1](name='loewdin', projection='orbital', orbitals='d', spin=True) + C = out.read_charge[-1](name="loewdin", projection="orbital", orbitals="d") + S = out.read_charge[-1]( + name="loewdin", projection="orbital", orbitals="d", spin=True + ) assert C[0] == 0.315910 assert S is None + @pytest.mark.only def test_charge_orbital_full_unpol(sisl_files): - f = sisl_files(_dir, 'molecule2.output') + f = sisl_files(_dir, "molecule2.output") out = stdoutSileORCA(f) - C = out.read_charge[-1](name='mulliken', projection='orbital', reduced=False) - S = out.read_charge[-1](name='mulliken', projection='orbital', reduced=False, spin=True) + C = out.read_charge[-1](name="mulliken", projection="orbital", reduced=False) + S = out.read_charge[-1]( + name="mulliken", projection="orbital", reduced=False, spin=True + ) assert C is None assert S is None + @pytest.mark.only def test_read_energy(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) E = out.read_energy[:]() @@ -266,16 +296,18 @@ def test_read_energy(sisl_files): E = out.read_energy[-1]() assert pytest.approx(E.total) == -3532.4784695729268 + def test_read_energy_vdw(sisl_files): - f = sisl_files(_dir, 'molecule2.output') + f = sisl_files(_dir, "molecule2.output") out = stdoutSileORCA(f) E = out.read_energy[-1]() assert E.vdw != 0 assert pytest.approx(E.total) == -3081.2640328972802 + def test_read_orbital_energies(sisl_files): - f = sisl_files(_dir, 'molecule.output') + f = sisl_files(_dir, "molecule.output") out = stdoutSileORCA(f) E = out.read_orbital_energies[:]() @@ -292,8 +324,9 @@ def test_read_orbital_energies(sisl_files): assert E.shape == (out.info.no, 2) assert pytest.approx(E[61, 0]) == 1173.4259 + def test_read_orbital_energies_unpol(sisl_files): - f = sisl_files(_dir, 'molecule2.output') + f = sisl_files(_dir, "molecule2.output") out = stdoutSileORCA(f) E = out.read_orbital_energies[:]() @@ -307,8 +340,9 @@ def test_read_orbital_energies_unpol(sisl_files): assert pytest.approx(E[0]) == -513.0976 assert pytest.approx(E[61]) == 1171.5967 + def test_multiple_calls(sisl_files): - f = sisl_files(_dir, 'molecule2.output') + f = sisl_files(_dir, "molecule2.output") out = stdoutSileORCA(f) N = out.read_electrons[:]() @@ -320,7 +354,7 @@ def test_multiple_calls(sisl_files): E = out.read_energy[:]() assert len(E) == 2 - C = out.read_charge[:](name='mulliken', projection='atom') + C = out.read_charge[:](name="mulliken", projection="atom") assert len(C) == 2 N = out.read_electrons[:]() diff --git a/src/sisl/io/orca/tests/test_txt.py b/src/sisl/io/orca/tests/test_txt.py index 9dbf0b1cbf..66de7ae465 100644 --- a/src/sisl/io/orca/tests/test_txt.py +++ b/src/sisl/io/orca/tests/test_txt.py @@ -11,17 +11,18 @@ from sisl.io.orca.txt import * pytestmark = [pytest.mark.io, pytest.mark.orca] -_dir = osp.join('sisl', 'io', 'orca') +_dir = osp.join("sisl", "io", "orca") def test_tags(sisl_files): - f = sisl_files(_dir, 'molecule_property.txt') + f = sisl_files(_dir, "molecule_property.txt") out = txtSileORCA(f) assert out.info.na == 2 assert out.info.no == None + def test_read_electrons(sisl_files): - f = sisl_files(_dir, 'molecule_property.txt') + f = sisl_files(_dir, "molecule_property.txt") out = txtSileORCA(f) N = out.read_electrons[:]() assert N[0, 0] == 7.9999985377 @@ -30,16 +31,18 @@ def test_read_electrons(sisl_files): assert N[0] == 7.9999985377 assert N[1] == 6.9999989872 + def test_read_energy(sisl_files): - f = sisl_files(_dir, 'molecule_property.txt') + f = sisl_files(_dir, "molecule_property.txt") out = txtSileORCA(f) E = out.read_energy[:]() assert len(E) == 2 E = out.read_energy[-1]() assert pytest.approx(E.total) == -3532.4797529097723 + def test_read_energy_vdw(sisl_files): - f = sisl_files(_dir, 'molecule2_property.txt') + f = sisl_files(_dir, "molecule2_property.txt") out = txtSileORCA(f) E = out.read_energy[:]() assert len(E) == 2 @@ -50,8 +53,9 @@ def test_read_energy_vdw(sisl_files): assert pytest.approx(E.total) == -3081.2651523149702 assert pytest.approx(E.vdw) == -0.011180550414138613 + def test_read_geometry(sisl_files): - f = sisl_files(_dir, 'molecule_property.txt') + f = sisl_files(_dir, "molecule_property.txt") out = txtSileORCA(f) G = out.read_geometry[:]() assert G[0].xyz[0, 0] == 0.421218019838 @@ -63,11 +67,12 @@ def test_read_geometry(sisl_files): assert G.xyz[1, 0] == 1.578781789721 assert G.xyz[0, 1] == 0.0 assert G.xyz[1, 1] == 0.0 - assert G.atoms[0].tag == 'N' - assert G.atoms[1].tag == 'O' + assert G.atoms[0].tag == "N" + assert G.atoms[1].tag == "O" + def test_multiple_calls(sisl_files): - f = sisl_files(_dir, 'molecule_property.txt') + f = sisl_files(_dir, "molecule_property.txt") out = txtSileORCA(f) N = out.read_electrons[:]() assert len(N) == 2 diff --git a/src/sisl/io/orca/txt.py b/src/sisl/io/orca/txt.py index ff76a2eaf0..a922feb48a 100644 --- a/src/sisl/io/orca/txt.py +++ b/src/sisl/io/orca/txt.py @@ -13,40 +13,54 @@ from ..sile import add_sile, sile_fh_open from .sile import SileORCA -__all__ = ['txtSileORCA'] +__all__ = ["txtSileORCA"] _A = SileORCA.InfoAttr @set_module("sisl.io.orca") class txtSileORCA(SileORCA): - """ Output from the ORCA property.txt file """ + """Output from the ORCA property.txt file""" _info_attributes_ = [ - _A("na", r".*Number of atoms:", - lambda attr, match: int(match.string.split()[-1])), - _A("no", r".*Number of basis functions:", - lambda attr, match: int(match.string.split()[-1])), - _A("vdw_correction", r".*\$ VdW_Correction", - lambda attr, match: True, default=False), + _A( + "na", + r".*Number of atoms:", + lambda attr, match: int(match.string.split()[-1]), + ), + _A( + "no", + r".*Number of basis functions:", + lambda attr, match: int(match.string.split()[-1]), + ), + _A( + "vdw_correction", + r".*\$ VdW_Correction", + lambda attr, match: True, + default=False, + ), ] @property - @deprecation("txtSileORCA.na is deprecated in favor of txtSileORCA.info.na", "0.16.0") + @deprecation( + "txtSileORCA.na is deprecated in favor of txtSileORCA.info.na", "0.16.0" + ) def na(self): - """ Number of atoms """ + """Number of atoms""" return self.info.na @property - @deprecation("txtSileORCA.no is deprecated in favor of txtSileORCA.info.no", "0.16.0") + @deprecation( + "txtSileORCA.no is deprecated in favor of txtSileORCA.info.no", "0.16.0" + ) def no(self): - """ Number of orbitals (basis functions) """ + """Number of orbitals (basis functions)""" return self.info.no @SileBinder(postprocess=np.array) @sile_fh_open() def read_electrons(self): - """ Read number of electrons (alpha, beta) + """Read number of electrons (alpha, beta) Returns ------- @@ -63,7 +77,7 @@ def read_electrons(self): @SileBinder() @sile_fh_open() def read_energy(self): - """ Reads the energy blocks + """Reads the energy blocks Returns ------- @@ -73,11 +87,11 @@ def read_energy(self): f = self.step_to("$ DFT_Energy", allow_reread=False)[0] if not f: return None - self.readline() # description - self.readline() # geom. index - self.readline() # prop. index + self.readline() # description + self.readline() # geom. index + self.readline() # prop. index - Ha2eV = units('Ha', 'eV') + Ha2eV = units("Ha", "eV") E = PropertyDict() line = self.readline() @@ -109,7 +123,7 @@ def read_energy(self): @SileBinder() @sile_fh_open() def read_geometry(self): - """ Reads the geometry from ORCA property.txt file + """Reads the geometry from ORCA property.txt file Returns ------- @@ -122,8 +136,8 @@ def read_geometry(self): line = self.readline() na = int(line.split()[-1]) - self.readline() # skip Geometry index - self.readline() # skip Coordinates line + self.readline() # skip Geometry index + self.readline() # skip Coordinates line atoms = [] xyz = np.empty([na, 3], np.float64) @@ -135,4 +149,4 @@ def read_geometry(self): return Geometry(xyz, atoms) -add_sile('txt', txtSileORCA, gzip=True) +add_sile("txt", txtSileORCA, gzip=True) diff --git a/src/sisl/io/pdb.py b/src/sisl/io/pdb.py index 7f2bff04e5..f8e8303b46 100644 --- a/src/sisl/io/pdb.py +++ b/src/sisl/io/pdb.py @@ -14,40 +14,40 @@ # Import sile objects from .sile import * -__all__ = ['pdbSile'] +__all__ = ["pdbSile"] @set_module("sisl.io") class pdbSile(Sile): - r""" PDB file object """ + r"""PDB file object""" def _setup(self, *args, **kwargs): - """ Instantiate counters """ + """Instantiate counters""" super()._setup(*args, **kwargs) self._model = 1 self._serial = 1 self._wrote_header = False def _w_sisl(self): - """ Placeholder for adding the header information """ + """Placeholder for adding the header information""" if self._wrote_header: return self._wrote_header = True - self._write('EXPDTA {:60s}\n'.format("THEORETICAL MODEL")) + self._write("EXPDTA {:60s}\n".format("THEORETICAL MODEL")) # Add dates, AUTHOR etc. def _w_model(self, start): - """ Writes the start of the next model """ + """Writes the start of the next model""" if start: - self._write(f'MODEL {self._model}\n') + self._write(f"MODEL {self._model}\n") self._model += 1 # Serial counter self._serial = 1 else: - self._write('ENDMDL\n') + self._write("ENDMDL\n") def _step_record(self, record, reread=True): - """ Step to a specific record entry in the PDB file """ + """Step to a specific record entry in the PDB file""" found = False # The previously read line number @@ -55,11 +55,11 @@ def _step_record(self, record, reread=True): while not found: l = self.readline() - if l == '': + if l == "": break found = l.startswith(record) - if not found and (l == '' and line > 0) and reread: + if not found and (l == "" and line > 0) and reread: # We may be in the case where the user request # reading the same twice... # So we need to re-read the file... @@ -70,65 +70,73 @@ def _step_record(self, record, reread=True): # Try and read again while not found and self._line <= line: l = self.readline() - if l == '': + if l == "": break found = l.startswith(record) return found, l @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def write_lattice(self, lattice): - """ Writes the supercell to the contained file """ + """Writes the supercell to the contained file""" # Check that we can write to the file sile_raise_write(self) # Get parameters and append the space group and specification - args = lattice.parameters() + ('P 1', 1) + args = lattice.parameters() + ("P 1", 1) - #COLUMNS DATA TYPE FIELD DEFINITION - #------------------------------------------------------------- + # COLUMNS DATA TYPE FIELD DEFINITION + # ------------------------------------------------------------- # 1 - 6 Record name "CRYST1" # 7 - 15 Real(9.3) a a (Angstroms). - #16 - 24 Real(9.3) b b (Angstroms). - #25 - 33 Real(9.3) c c (Angstroms). - #34 - 40 Real(7.2) alpha alpha (degrees). - #41 - 47 Real(7.2) beta beta (degrees). - #48 - 54 Real(7.2) gamma gamma (degrees). - #56 - 66 LString sGroup Space group. - #67 - 70 Integer z Z value. - self._write(('CRYST1' + '{:9.3f}' * 3 + '{:7.2f}' * 3 + '{:<11s}' + '{:4d}\n').format(*args)) - - #COLUMNS DATA TYPE FIELD DEFINITION - #------------------------------------------------------------------ + # 16 - 24 Real(9.3) b b (Angstroms). + # 25 - 33 Real(9.3) c c (Angstroms). + # 34 - 40 Real(7.2) alpha alpha (degrees). + # 41 - 47 Real(7.2) beta beta (degrees). + # 48 - 54 Real(7.2) gamma gamma (degrees). + # 56 - 66 LString sGroup Space group. + # 67 - 70 Integer z Z value. + self._write( + ("CRYST1" + "{:9.3f}" * 3 + "{:7.2f}" * 3 + "{:<11s}" + "{:4d}\n").format( + *args + ) + ) + + # COLUMNS DATA TYPE FIELD DEFINITION + # ------------------------------------------------------------------ # 1 - 6 Record name "SCALEn" n=1, 2, or 3 - #11 - 20 Real(10.6) s[n][1] Sn1 - #21 - 30 Real(10.6) s[n][2] Sn2 - #31 - 40 Real(10.6) s[n][3] Sn3 - #46 - 55 Real(10.5) u[n] Un + # 11 - 20 Real(10.6) s[n][1] Sn1 + # 21 - 30 Real(10.6) s[n][2] Sn2 + # 31 - 40 Real(10.6) s[n][3] Sn3 + # 46 - 55 Real(10.5) u[n] Un for i in range(3): - args = [i + 1] + lattice.cell[i, :].tolist() + [0.] - self._write(('SCALE{:1d} {:10.6f}{:10.6f}{:10.6f} {:10.5f}\n').format(*args)) + args = [i + 1] + lattice.cell[i, :].tolist() + [0.0] + self._write( + ("SCALE{:1d} {:10.6f}{:10.6f}{:10.6f} {:10.5f}\n").format(*args) + ) - #COLUMNS DATA TYPE FIELD DEFINITION - #---------------------------------------------------------------- + # COLUMNS DATA TYPE FIELD DEFINITION + # ---------------------------------------------------------------- # 1 - 6 Record name "ORIGXn" n=1, 2, or 3 - #11 - 20 Real(10.6) o[n][1] On1 - #21 - 30 Real(10.6) o[n][2] On2 - #31 - 40 Real(10.6) o[n][3] On3 - #46 - 55 Real(10.5) t[n] Tn - fmt = 'ORIGX{:1d} ' + '{:10.6f}' * 3 + '{:10.5f}\n' + # 11 - 20 Real(10.6) o[n][1] On1 + # 21 - 30 Real(10.6) o[n][2] On2 + # 31 - 40 Real(10.6) o[n][3] On3 + # 46 - 55 Real(10.5) t[n] Tn + fmt = "ORIGX{:1d} " + "{:10.6f}" * 3 + "{:10.5f}\n" for i in range(3): args = [i + 1, 0, 0, 0, lattice.origin[i]] self._write(fmt.format(*args)) @sile_fh_open() def read_lattice(self): - """ Read supercell from the contained file """ - f, line = self._step_record('CRYST1') + """Read supercell from the contained file""" + f, line = self._step_record("CRYST1") if not f: - raise SileError(str(self) + ' does not contain a CRYST1 record.') + raise SileError(str(self) + " does not contain a CRYST1 record.") a = float(line[6:15]) b = float(line[15:24]) c = float(line[24:33]) @@ -137,21 +145,21 @@ def read_lattice(self): gamma = float(line[47:54]) cell = Lattice.tocell([a, b, c, alpha, beta, gamma]) - f, line = self._step_record('SCALE1') + f, line = self._step_record("SCALE1") if f: cell[0, :] = float(line[11:20]), float(line[21:30]), float(line[31:40]) - f, line = self._step_record('SCALE2') + f, line = self._step_record("SCALE2") if not f: - raise SileError(str(self) + ' found SCALE1 but not SCALE2!') + raise SileError(str(self) + " found SCALE1 but not SCALE2!") cell[1, :] = float(line[11:20]), float(line[21:30]), float(line[31:40]) - f, line = self._step_record('SCALE3') + f, line = self._step_record("SCALE3") if not f: - raise SileError(str(self) + ' found SCALE1 but not SCALE3!') + raise SileError(str(self) + " found SCALE1 but not SCALE3!") cell[2, :] = float(line[11:20]), float(line[21:30]), float(line[31:40]) origin = np.zeros(3) for i in range(3): - f, line = self._step_record('ORIGX{}'.format(i + 1)) + f, line = self._step_record("ORIGX{}".format(i + 1)) if f: origin[i] = float(line[45:55]) @@ -159,7 +167,7 @@ def read_lattice(self): @sile_fh_open() def write_geometry(self, geometry): - """ Writes the geometry to the contained file + """Writes the geometry to the contained file Parameters ---------- @@ -172,31 +180,54 @@ def write_geometry(self, geometry): self._w_model(True) # Generically the configuration (model) are non-polymers, hence we use the HETATM type - atom = 'HETATM' + atom = "HETATM" - #COLUMNS DATA TYPE FIELD DEFINITION - #------------------------------------------------------------------------------------- + # COLUMNS DATA TYPE FIELD DEFINITION + # ------------------------------------------------------------------------------------- # 1 - 6 Record name "ATOM " # 7 - 11 Integer serial Atom serial number. - #13 - 16 Atom name Atom name. - #17 Character altLoc Alternate location indicator. - #18 - 20 Residue name resName Residue name. - #22 Character chainID Chain identifier. - #23 - 26 Integer resSeq Residue sequence number. - #27 AChar iCode Code for insertion of residues. - #31 - 38 Real(8.3) x Orthogonal coordinates for X in Angstroms. - #39 - 46 Real(8.3) y Orthogonal coordinates for Y in Angstroms. - #47 - 54 Real(8.3) z Orthogonal coordinates for Z in Angstroms. - #55 - 60 Real(6.2) occupancy Occupancy. - #61 - 66 Real(6.2) tempFactor Temperature factor. - #77 - 78 LString(2) element Element symbol, right-justified. - #79 - 80 LString(2) charge Charge on the atom. - fmt = f'{atom:<6s}' + '{:5d} {:<4s}{:1s}{:<3s} {:1s}{:4d}{:1s} ' + '{:8.3f}' * 3 + '{:6.2f}' * 2 + ' ' * 10 + '{:2s}' * 2 + '\n' + # 13 - 16 Atom name Atom name. + # 17 Character altLoc Alternate location indicator. + # 18 - 20 Residue name resName Residue name. + # 22 Character chainID Chain identifier. + # 23 - 26 Integer resSeq Residue sequence number. + # 27 AChar iCode Code for insertion of residues. + # 31 - 38 Real(8.3) x Orthogonal coordinates for X in Angstroms. + # 39 - 46 Real(8.3) y Orthogonal coordinates for Y in Angstroms. + # 47 - 54 Real(8.3) z Orthogonal coordinates for Z in Angstroms. + # 55 - 60 Real(6.2) occupancy Occupancy. + # 61 - 66 Real(6.2) tempFactor Temperature factor. + # 77 - 78 LString(2) element Element symbol, right-justified. + # 79 - 80 LString(2) charge Charge on the atom. + fmt = ( + f"{atom:<6s}" + + "{:5d} {:<4s}{:1s}{:<3s} {:1s}{:4d}{:1s} " + + "{:8.3f}" * 3 + + "{:6.2f}" * 2 + + " " * 10 + + "{:2s}" * 2 + + "\n" + ) xyz = geometry.xyz # Current U is used for "UNKNOWN" input. Possibly the user can specify this later. for ia in geometry: a = geometry.atoms[ia] - args = [self._serial, a.tag, 'U', 'U1', 'U', 1, 'U', xyz[ia, 0], xyz[ia, 1], xyz[ia, 2], a.q0.sum(), 0., a.symbol, '0'] + args = [ + self._serial, + a.tag, + "U", + "U1", + "U", + 1, + "U", + xyz[ia, 0], + xyz[ia, 1], + xyz[ia, 2], + a.q0.sum(), + 0.0, + a.symbol, + "0", + ] # Step serial self._serial += 1 self._write(fmt.format(*args)) @@ -206,13 +237,13 @@ def write_geometry(self, geometry): @sile_fh_open() def read_geometry(self): - """ Read geometry from the contained file """ + """Read geometry from the contained file""" # First we read in the geometry lattice = self.read_lattice() # Try and go to the first model record - in_model, l = self._step_record('MODEL') + in_model, l = self._step_record("MODEL") idx = [] tags = [] @@ -220,10 +251,13 @@ def read_geometry(self): Z = [] if in_model: l = self.readline() + def is_atom(line): - return l.startswith('ATOM') or l.startswith('HETATM') + return l.startswith("ATOM") or l.startswith("HETATM") + def is_end_model(line): - return l.startswith('ENDMDL') or l == '' + return l.startswith("ENDMDL") or l == "" + while not is_end_model(l): if is_atom(l): idx.append(int(l[6:11])) @@ -241,10 +275,10 @@ def is_end_model(line): return Geometry(xyz, atoms, lattice=lattice) def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('pdb', pdbSile, case=False, gzip=True) +add_sile("pdb", pdbSile, case=False, gzip=True) diff --git a/src/sisl/io/scaleup/__init__.py b/src/sisl/io/scaleup/__init__.py index 1b066f3291..f617220ed4 100644 --- a/src/sisl/io/scaleup/__init__.py +++ b/src/sisl/io/scaleup/__init__.py @@ -13,7 +13,7 @@ rhamSileScaleUp - Hamiltonian file """ -from .sile import * # isort: split +from .sile import * # isort: split from .orbocc import * from .ref import * from .rham import * diff --git a/src/sisl/io/scaleup/orbocc.py b/src/sisl/io/scaleup/orbocc.py index da19dc0f43..a195c4fe9e 100644 --- a/src/sisl/io/scaleup/orbocc.py +++ b/src/sisl/io/scaleup/orbocc.py @@ -8,21 +8,22 @@ from ..sile import * from .sile import SileScaleUp -__all__ = ['orboccSileScaleUp'] +__all__ = ["orboccSileScaleUp"] class orboccSileScaleUp(SileScaleUp): - """ orbocc file object for ScaleUp """ + """orbocc file object for ScaleUp""" @sile_fh_open() def read_atom(self): - """ Reads a the atoms and returns an `Atoms` object """ + """Reads a the atoms and returns an `Atoms` object""" self.readline() _, ns = map(int, self.readline().split()[:2]) - species = self.readline().split()[:ns] # species - orbs = self.readline().split()[:ns] # orbs per species + species = self.readline().split()[:ns] # species + orbs = self.readline().split()[:ns] # orbs per species # Create list of species with correct # of orbitals per specie species = [Atom(s, [-1] * int(o)) for s, o in zip(species, orbs)] return Atoms(species) -add_sile('orbocc', orboccSileScaleUp, case=False, gzip=True) + +add_sile("orbocc", orboccSileScaleUp, case=False, gzip=True) diff --git a/src/sisl/io/scaleup/ref.py b/src/sisl/io/scaleup/ref.py index 7d0e1d530d..27c3a8870c 100644 --- a/src/sisl/io/scaleup/ref.py +++ b/src/sisl/io/scaleup/ref.py @@ -12,22 +12,22 @@ from ..sile import * from .sile import SileScaleUp -__all__ = ['refSileScaleUp', 'restartSileScaleUp'] +__all__ = ["refSileScaleUp", "restartSileScaleUp"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ang2Bohr = unit_convert('Ang', 'Bohr') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ang2Bohr = unit_convert("Ang", "Bohr") class refSileScaleUp(SileScaleUp): - """ REF file object for ScaleUp """ + """REF file object for ScaleUp""" @sile_fh_open() def read_lattice(self): - """ Reads a supercell from the Sile """ + """Reads a supercell from the Sile""" # 1st line is number of supercells nsc = _a.fromiteri(map(int, self.readline().split()[:3])) - self.readline() # natoms, nspecies - self.readline() # species + self.readline() # natoms, nspecies + self.readline() # species cell = _a.fromiterd(map(float, self.readline().split()[:9])) # Typically ScaleUp uses very large unit-cells # so supercells will typically be restricted to [3, 3, 3] @@ -35,7 +35,7 @@ def read_lattice(self): @sile_fh_open() def read_geometry(self, primary=False, **kwargs): - """ Reads a geometry from the Sile """ + """Reads a geometry from the Sile""" # 1st line is number of supercells nsc = _a.fromiteri(map(int, self.readline().split()[:3])) na, ns = map(int, self.readline().split()[:2]) @@ -61,15 +61,15 @@ def read_geometry(self, primary=False, **kwargs): cell[2, :] /= nsc[2] except Exception: c = np.empty([3, 3], np.float64) - c[0, 0] = 1. + cell[0] - c[0, 1] = cell[5] / 2. - c[0, 2] = cell[4] / 2. - c[1, 0] = cell[5] / 2. - c[1, 1] = 1. + cell[1] - c[1, 2] = cell[3] / 2. - c[2, 0] = cell[4] / 2. - c[2, 1] = cell[3] / 2. - c[2, 2] = 1. + cell[2] + c[0, 0] = 1.0 + cell[0] + c[0, 1] = cell[5] / 2.0 + c[0, 2] = cell[4] / 2.0 + c[1, 0] = cell[5] / 2.0 + c[1, 1] = 1.0 + cell[1] + c[1, 2] = cell[3] / 2.0 + c[2, 0] = cell[4] / 2.0 + c[2, 1] = cell[3] / 2.0 + c[2, 2] = 1.0 + cell[2] cell = c * Ang2Bohr lattice = Lattice(cell * Bohr2Ang, nsc=nsc) @@ -79,7 +79,6 @@ def read_geometry(self, primary=False, **kwargs): # Read the geometry for ia in range(na * ns): - # Retrieve line # ix iy iz ia is x y z line = self.readline().split() @@ -90,28 +89,28 @@ def read_geometry(self, primary=False, **kwargs): return Geometry(xyz * Bohr2Ang, atoms, lattice=lattice) @sile_fh_open() - def write_geometry(self, geometry, fmt='18.8e'): - """ Writes the geometry to the contained file """ + def write_geometry(self, geometry, fmt="18.8e"): + """Writes the geometry to the contained file""" # Check that we can write to the file sile_raise_write(self) # 1st line is number of supercells - self._write('{:5d}{:5d}{:5d}\n'.format(*geometry.lattice.nsc // 2 + 1)) + self._write("{:5d}{:5d}{:5d}\n".format(*geometry.lattice.nsc // 2 + 1)) # natoms, nspecies - self._write('{:5d}{:5d}\n'.format(len(geometry), len(geometry.atoms.atom))) + self._write("{:5d}{:5d}\n".format(len(geometry), len(geometry.atoms.atom))) - s = '' + s = "" for a in geometry.atoms.atom: # Append the species label - s += f'{a.tag:<10}' - self._write(s + '\n') + s += f"{a.tag:<10}" + self._write(s + "\n") - fmt_str = f'{{:{fmt}}} ' * 9 + '\n' - self._write(fmt_str.format(*(geometry.cell*Ang2Bohr).reshape(-1))) + fmt_str = f"{{:{fmt}}} " * 9 + "\n" + self._write(fmt_str.format(*(geometry.cell * Ang2Bohr).reshape(-1))) # Create line # ix iy iz ia is x y z - line = '{:5d}{:5d}{:5d}{:5d}{:5d}' + f'{{:{fmt}}}' * 3 + '\n' + line = "{:5d}{:5d}{:5d}{:5d}{:5d}" + f"{{:{fmt}}}" * 3 + "\n" args = [None] * 8 for _, isc in geometry.lattice: @@ -120,7 +119,6 @@ def write_geometry(self, geometry, fmt='18.8e'): # Write the geometry for ia in geometry: - args[0] = isc[0] args[1] = isc[1] args[2] = isc[2] @@ -133,7 +131,7 @@ def write_geometry(self, geometry, fmt='18.8e'): self._write(line.format(*args)) def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) @@ -141,10 +139,9 @@ def ArgumentParser(self, p=None, *args, **kwargs): # The restart file is _equivalent_ but with displacements class restartSileScaleUp(refSileScaleUp): - @sile_fh_open() def read_geometry(self, *args, **kwargs): - """ Read geometry of the restart file + """Read geometry of the restart file This will also try and read the corresponding .REF file such that final coordinates are returned. @@ -163,12 +160,13 @@ def read_geometry(self, *args, **kwargs): restart = super().read_geometry() if not ref is None: - restart.lattice = Lattice(np.dot(ref.lattice.cell, restart.lattice.cell.T), - nsc=restart.nsc) + restart.lattice = Lattice( + np.dot(ref.lattice.cell, restart.lattice.cell.T), nsc=restart.nsc + ) restart.xyz += ref.xyz return restart -add_sile('REF', refSileScaleUp, case=False, gzip=True) -add_sile('restart', restartSileScaleUp, case=False, gzip=True) +add_sile("REF", refSileScaleUp, case=False, gzip=True) +add_sile("restart", restartSileScaleUp, case=False, gzip=True) diff --git a/src/sisl/io/scaleup/rham.py b/src/sisl/io/scaleup/rham.py index e59ff7c0bb..4876b400ca 100644 --- a/src/sisl/io/scaleup/rham.py +++ b/src/sisl/io/scaleup/rham.py @@ -5,21 +5,22 @@ from scipy.sparse import lil_matrix from ..sile import * + # Import sile objects from .sile import SileScaleUp -__all__ = ['rhamSileScaleUp'] +__all__ = ["rhamSileScaleUp"] class rhamSileScaleUp(SileScaleUp): - """ rham file object for ScaleUp + """rham file object for ScaleUp This file contains the real-space Hamiltonian for a ScaleUp simulation """ @sile_fh_open() def read_hamiltonian(self, geometry=None): - """ Reads a Hamiltonian from the Sile """ + """Reads a Hamiltonian from the Sile""" from sisl import Hamiltonian # Create a copy as we may change the @@ -38,7 +39,7 @@ def pl(line): isc = list(map(int, l[1:4])) o1, o2 = map(int, l[4:6]) rH, iH = map(float, l[6:8]) - return s, isc, o1-1, o2-1, rH, iH + return s, isc, o1 - 1, o2 - 1, rH, iH ns = 0 no = 0 @@ -57,7 +58,9 @@ def pl(line): if no + 1 != g.no: # First try and read the orbocc file try: - species = get_sile(str(self.file).replace(".rham", ".orbocc")).read_atom() + species = get_sile( + str(self.file).replace(".rham", ".orbocc") + ).read_atom() for i, atom in enumerate(species.atom): g.atoms._atom[i] = atom except Exception: @@ -66,8 +69,10 @@ def pl(line): # Check again, to be sure... if no + 1 != g.no: - raise ValueError('The Geometry has a different number of ' - 'orbitals, please correct by adding the orbocc file.') + raise ValueError( + "The Geometry has a different number of " + "orbitals, please correct by adding the orbocc file." + ) # Now, we know the size etc. of the Hamiltonian m_sc = m_sc * 2 + 1 @@ -83,12 +88,11 @@ def pl(line): old_s = 0 for s, isc, o1, o2, rH, iH in lines: - if s != old_s: # We need to create a new Hamiltonian H = lil_matrix((no, no_s), dtype=np.float64) old_s = s - Hs[s-1] = H + Hs[s - 1] = H i = g.sc_index(isc) H[o1, o2 + i * no] = rH @@ -99,4 +103,4 @@ def pl(line): return H -add_sile('rham', rhamSileScaleUp, case=False, gzip=True) +add_sile("rham", rhamSileScaleUp, case=False, gzip=True) diff --git a/src/sisl/io/siesta/__init__.py b/src/sisl/io/siesta/__init__.py index f749bea557..6cd2d8c808 100644 --- a/src/sisl/io/siesta/__init__.py +++ b/src/sisl/io/siesta/__init__.py @@ -45,7 +45,7 @@ tsvncSileSiesta - TranSiesta potential solution input file """ -from .sile import * # isort: split +from .sile import * # isort: split from .ani import * from .bands import * from .basis import * diff --git a/src/sisl/io/siesta/_help.py b/src/sisl/io/siesta/_help.py index 3255596250..7a5862ef93 100644 --- a/src/sisl/io/siesta/_help.py +++ b/src/sisl/io/siesta/_help.py @@ -8,17 +8,18 @@ try: from . import _siesta + has_fortran_module = True except ImportError: has_fortran_module = False -__all__ = ['_csr_from_siesta', '_csr_from_sc_off'] -__all__ += ['_csr_to_siesta', '_csr_to_sc_off'] -__all__ += ['_mat_spin_convert', "_fc_correct"] +__all__ = ["_csr_from_siesta", "_csr_from_sc_off"] +__all__ += ["_csr_to_siesta", "_csr_to_sc_off"] +__all__ += ["_mat_spin_convert", "_fc_correct"] def _ensure_diagonal(csr): - """ Ensures that the sparsity pattern has diagonal entries + """Ensures that the sparsity pattern has diagonal entries This will set the wrong values in non-orthogonal basis-sets since missing items will be set to 0 which should be 1 in @@ -38,19 +39,21 @@ def _ensure_diagonal(csr): missing_diags = np.delete(np.arange(csr.shape[0]), present_diags) for row in missing_diags: - csr[row, row] = 0. + csr[row, row] = 0.0 def _csr_from_siesta(geom, csr): - """ Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc """ + """Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc""" if not has_fortran_module: - raise SislError('sisl cannot convert the sparse matrix from a Siesta conforming sparsity pattern! Please install with fortran support!') + raise SislError( + "sisl cannot convert the sparse matrix from a Siesta conforming sparsity pattern! Please install with fortran support!" + ) _csr_from_sc_off(geom, _siesta.siesta_sc_off(*geom.nsc).T, csr) def _csr_to_siesta(geom, csr, diag=True): - """ Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc + """Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc Parameters ---------- @@ -59,7 +62,9 @@ def _csr_to_siesta(geom, csr, diag=True): whether the csr matrix will be ensured diagonal as well """ if not has_fortran_module: - raise SislError('sisl cannot convert the sparse matrix into a Siesta conforming sparsity pattern! Please install with fortran support!') + raise SislError( + "sisl cannot convert the sparse matrix into a Siesta conforming sparsity pattern! Please install with fortran support!" + ) if diag: _ensure_diagonal(csr) @@ -67,34 +72,38 @@ def _csr_to_siesta(geom, csr, diag=True): def _csr_from_sc_off(geom, sc_off, csr): - """ Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc """ + """Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc""" nsc = geom.lattice.nsc.astype(np.int32) sc = geom.lattice.__class__([1], nsc=nsc) sc.sc_off = sc_off from_sc_off = sc.sc_index(geom.sc_off) # this transfers the local siesta csr matrix ordering to the geometry ordering - col_from = (from_sc_off.reshape(-1, 1) * geom.no + _a.arangei(geom.no).reshape(1, -1)).ravel() + col_from = ( + from_sc_off.reshape(-1, 1) * geom.no + _a.arangei(geom.no).reshape(1, -1) + ).ravel() _csr_from(col_from, csr) def _csr_to_sc_off(geom, sc_off, csr): - """ Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc """ + """Internal routine to convert *any* SparseCSR matrix from sisl nsc to siesta nsc""" # Find the equivalent indices in the geometry supercell to_sc_off = geom.sc_index(sc_off) # this transfers the local csr matrix ordering to the geometry ordering - col_from = (to_sc_off.reshape(-1, 1) * geom.no + _a.arangei(geom.no).reshape(1, -1)).ravel() + col_from = ( + to_sc_off.reshape(-1, 1) * geom.no + _a.arangei(geom.no).reshape(1, -1) + ).ravel() _csr_from(col_from, csr) def _csr_from(col_from, csr): - """ Internal routine to convert columns in a SparseCSR matrix """ + """Internal routine to convert columns in a SparseCSR matrix""" # local csr matrix ordering col_to = _a.arangei(csr.shape[1]) csr.translate_columns(col_from, col_to) def _mat_spin_convert(M, spin=None): - """ Conversion of Siesta spin matrices to sisl spin matrices + """Conversion of Siesta spin matrices to sisl spin matrices The matrices from Siesta are given in a format adheering to the following concept: @@ -135,7 +144,7 @@ def _mat_spin_convert(M, spin=None): def _geom2hsx(geometry): - """ Convert the geometry into the correct lists of species and lists """ + """Convert the geometry into the correct lists of species and lists""" atoms = geometry.atoms nspecie = atoms.nspecie isa = atoms.specie @@ -157,7 +166,7 @@ def _geom2hsx(geometry): def _fc_correct(fc, trans_inv=True, sum0=True, hermitian=True): - r""" Correct a force-constant matrix to retain translational invariance and sum of forces == 0 on atoms + r"""Correct a force-constant matrix to retain translational invariance and sum of forces == 0 on atoms Parameters ---------- @@ -186,7 +195,9 @@ def _fc_correct(fc, trans_inv=True, sum0=True, hermitian=True): is_subset = shape[0] != shape[5] if is_subset: - raise ValueError(f"fc_correct cannot figure out the displaced atoms in the unit-cell, please limit atoms to the displaced atoms.") + raise ValueError( + f"fc_correct cannot figure out the displaced atoms in the unit-cell, please limit atoms to the displaced atoms." + ) # NOTE: # This is not exactly the same as Vibra does it. diff --git a/src/sisl/io/siesta/ani.py b/src/sisl/io/siesta/ani.py index 98681e449b..3b808132c9 100644 --- a/src/sisl/io/siesta/ani.py +++ b/src/sisl/io/siesta/ani.py @@ -18,4 +18,4 @@ class aniSileSiesta(xyzSile, SileSiesta): pass -add_sile('ANI', aniSileSiesta, gzip=True) +add_sile("ANI", aniSileSiesta, gzip=True) diff --git a/src/sisl/io/siesta/bands.py b/src/sisl/io/siesta/bands.py index e7a99d54b8..820e2bf463 100644 --- a/src/sisl/io/siesta/bands.py +++ b/src/sisl/io/siesta/bands.py @@ -14,17 +14,17 @@ @set_module("sisl.io.siesta") class bandsSileSiesta(SileSiesta): - """ Bandstructure information """ + """Bandstructure information""" @sile_fh_open(True) def read_fermi_level(self): - """ Returns the Fermi level in the bands file """ + """Returns the Fermi level in the bands file""" # Luckily the data is in eV return float(self.readline()) @sile_fh_open() def read_data(self, as_dataarray=False): - """ Returns data associated with the bands file + """Returns data associated with the bands file The energy levels are shifted with respect to the Fermi-level. @@ -32,7 +32,7 @@ def read_data(self, as_dataarray=False): -------- as_dataarray: boolean, optional if `True`, the information is returned as an `xarray.DataArray` - Ticks (if read) are stored as an attribute of the DataArray + Ticks (if read) are stored as an attribute of the DataArray (under `array.ticks` and `array.ticklabels`) """ band_lines = False @@ -85,7 +85,7 @@ def read_data(self, as_dataarray=False): for _ in range(nl): l = self.readline().split() xlabels.append(float(l[0])) - labels.append((' '.join(l[1:])).replace("'", '')) + labels.append((" ".join(l[1:])).replace("'", "")) vals = (xlabels, labels), *vals if as_dataarray: @@ -93,17 +93,23 @@ def read_data(self, as_dataarray=False): ticks = {"ticks": xlabels, "ticklabels": labels} if band_lines else {} - vals = DataArray(eb, name="Energy", attrs=ticks, - coords=[("k", k), - ("spin", _a.arangei(0, eb.shape[1])), - ("band", _a.arangei(0, eb.shape[2]))]) + vals = DataArray( + eb, + name="Energy", + attrs=ticks, + coords=[ + ("k", k), + ("spin", _a.arangei(0, eb.shape[1])), + ("band", _a.arangei(0, eb.shape[2])), + ], + ) return vals @default_ArgumentParser(description="Manipulate bands file in sisl.") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ - #limit_args = kwargs.get("limit_arguments", True) + """Returns the arguments that is available for this Sile""" + # limit_args = kwargs.get("limit_arguments", True) short = kwargs.get("short", False) def opts(*args): @@ -119,28 +125,32 @@ def opts(*args): # This will enable custom actions to interact with the geometry in a # straight forward manner. namespace = default_namespace( - _bands= self.read_data(), + _bands=self.read_data(), _Emap=None, ) # Energy grabs class ERange(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): ns._Emap = strmap(float, value)[0] - p.add_argument("--energy", "-E", - action=ERange, - help="Denote the sub-section of energies that are plotted: '-1:0,1:2' [eV]") - class BandsPlot(argparse.Action): + p.add_argument( + "--energy", + "-E", + action=ERange, + help="Denote the sub-section of energies that are plotted: '-1:0,1:2' [eV]", + ) + class BandsPlot(argparse.Action): def __call__(self, parser, ns, value, option_string=None): import matplotlib.pyplot as plt # Decide whether this is BandLines or BandPoints if len(ns._bands) == 2: # We do not plot "points" - raise ValueError("The bands file only contains points in the BZ, not a bandstructure.") + raise ValueError( + "The bands file only contains points in the BZ, not a bandstructure." + ) lbls, k, b = ns._bands b = b.T # Extract to tick-marks and names @@ -163,7 +173,9 @@ def myplot(ax, title, x, y, E): ax[1].set_xticklabels(lbls, rotation=45) # We must plot spin-up/down separately for i, ud in enumerate(["UP", "DOWN"]): - myplot(ax[i], f"Bandstructure SPIN-{ud}", k, b[:, i, :], ns._Emap) + myplot( + ax[i], f"Bandstructure SPIN-{ud}", k, b[:, i, :], ns._Emap + ) else: plt.figure() ax = plt.gca() @@ -174,8 +186,14 @@ def myplot(ax, title, x, y, E): plt.show() else: plt.savefig(value) - p.add_argument(*opts("--plot", "-p"), action=BandsPlot, nargs="?", metavar="FILE", - help="Plot the bandstructure from the .bands file, possibly saving to a file.") + + p.add_argument( + *opts("--plot", "-p"), + action=BandsPlot, + nargs="?", + metavar="FILE", + help="Plot the bandstructure from the .bands file, possibly saving to a file.", + ) return p, namespace diff --git a/src/sisl/io/siesta/basis.py b/src/sisl/io/siesta/basis.py index 0c8ac12b07..820cefdb68 100644 --- a/src/sisl/io/siesta/basis.py +++ b/src/sisl/io/siesta/basis.py @@ -13,25 +13,25 @@ from ..sile import add_sile, sile_fh_open from .sile import SileCDFSiesta, SileSiesta -__all__ = ['ionxmlSileSiesta', 'ionncSileSiesta'] +__all__ = ["ionxmlSileSiesta", "ionncSileSiesta"] @set_module("sisl.io.siesta") class ionxmlSileSiesta(SileSiesta): - """ Basis set information in xml format + """Basis set information in xml format Note that the ``ion`` files are equivalent to the ``ion.xml`` files. """ @sile_fh_open(True) def read_basis(self): - """ Returns data associated with the ion.xml file """ + """Returns data associated with the ion.xml file""" # Get the element-tree root = xml_parse(self.fh).getroot() # Get number of orbitals label = root.find("label").text.strip() - Z = int(root.find("z").text) # atomic number, negative for floating + Z = int(root.find("z").text) # atomic number, negative for floating mass = float(root.find("mass").text) # Read in the PAO"s @@ -43,10 +43,9 @@ def read_basis(self): # All orbital data Bohr2Ang = unit_convert("Bohr", "Ang") for orb in paos: - n = int(orb.get("n")) l = int(orb.get("l")) - z = int(orb.get("z")) # zeta + z = int(orb.get("z")) # zeta q0 = float(orb.get("population")) @@ -72,7 +71,7 @@ def read_basis(self): # The fact that we have to have it normalized means that we need # to convert psi /sqrt(Bohr**3) -> /sqrt(Ang**3) # \int psi^\dagger psi == 1 - psi = dat[1::2] * r ** l / Bohr2Ang ** (3./2.) + psi = dat[1::2] * r**l / Bohr2Ang ** (3.0 / 2.0) # Create the sphericalorbital and then the atomicorbital sorb = SphericalOrbital(l, (r * Bohr2Ang, psi), q0) @@ -86,13 +85,13 @@ def read_basis(self): @set_module("sisl.io.siesta") class ionncSileSiesta(SileCDFSiesta): - """ Basis set information in NetCDF files + """Basis set information in NetCDF files Note that the ``ion.nc`` files are equivalent to the ``ion.xml`` files. """ def read_basis(self): - """ Returns data associated with the ion.xml file """ + """Returns data associated with the ion.xml file""" no = len(self._dimension("norbs")) # Get number of orbitals @@ -101,12 +100,12 @@ def read_basis(self): mass = float(self.Mass) # Retrieve values - orb_l = self._variable("orbnl_l")[:] # angular quantum number - orb_n = self._variable("orbnl_n")[:] # principal quantum number - orb_z = self._variable("orbnl_z")[:] # zeta - orb_P = self._variable("orbnl_ispol")[:] > 0 # polarization shell, or not - orb_q0 = self._variable("orbnl_pop")[:] # q0 for the orbitals - orb_delta = self._variable("delta")[:] # delta for the functions + orb_l = self._variable("orbnl_l")[:] # angular quantum number + orb_n = self._variable("orbnl_n")[:] # principal quantum number + orb_z = self._variable("orbnl_z")[:] # zeta + orb_P = self._variable("orbnl_ispol")[:] > 0 # polarization shell, or not + orb_q0 = self._variable("orbnl_pop")[:] # q0 for the orbitals + orb_delta = self._variable("delta")[:] # delta for the functions orb_psi = self._variable("orb")[:, :] # Now loop over all orbitals @@ -115,7 +114,6 @@ def read_basis(self): # All orbital data Bohr2Ang = unit_convert("Bohr", "Ang") for io in range(no): - n = orb_n[io] l = orb_l[io] z = orb_z[io] @@ -134,7 +132,7 @@ def read_basis(self): # The fact that we have to have it normalized means that we need # to convert psi /sqrt(Bohr**3) -> /sqrt(Ang**3) # \int psi^\dagger psi == 1 - psi = orb_psi[io, :] * r ** l / Bohr2Ang ** (3./2.) + psi = orb_psi[io, :] * r**l / Bohr2Ang ** (3.0 / 2.0) # Create the sphericalorbital and then the atomicorbital sorb = SphericalOrbital(l, (r * Bohr2Ang, psi), orb_q0[io]) @@ -147,8 +145,8 @@ def read_basis(self): @default_ArgumentParser(description="Extracting basis-set information.") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ - #limit_args = kwargs.get("limit_arguments", True) + """Returns the arguments that is available for this Sile""" + # limit_args = kwargs.get("limit_arguments", True) short = kwargs.get("short", False) def opts(*args): @@ -177,7 +175,7 @@ def opts(*args): # this gets converted later delta = self._variable("delta")[:] r = aranged(ion_nc.orbital.shape[1]).reshape(1, -1) * delta.reshape(-1, 1) - ion_nc.orbital *= r ** ion_nc.l.reshape(-1, 1) / Bohr2Ang * (3./2.) + ion_nc.orbital *= r ** ion_nc.l.reshape(-1, 1) / Bohr2Ang * (3.0 / 2.0) ion_nc.r = r * Bohr2Ang ion_nc.kb = PropertyDict() ion_nc.kb.n = self._variable("pjnl_n")[:] @@ -186,26 +184,26 @@ def opts(*args): ion_nc.kb.proj = self._variable("proj")[:] delta = self._variable("kbdelta")[:] r = aranged(ion_nc.kb.proj.shape[1]).reshape(1, -1) * delta.reshape(-1, 1) - ion_nc.kb.proj *= r ** ion_nc.kb.l.reshape(-1, 1) / Bohr2Ang * (3./2.) + ion_nc.kb.proj *= r ** ion_nc.kb.l.reshape(-1, 1) / Bohr2Ang * (3.0 / 2.0) ion_nc.kb.r = r * Bohr2Ang vna = self._variable("vna") r = aranged(vna[:].size) * vna.Vna_delta ion_nc.vna = PropertyDict() - ion_nc.vna.v = vna[:] * Ry2eV * r / Bohr2Ang ** 3 + ion_nc.vna.v = vna[:] * Ry2eV * r / Bohr2Ang**3 ion_nc.vna.r = r * Bohr2Ang # this is charge (not 1/sqrt(charge)) chlocal = self._variable("chlocal") r = aranged(chlocal[:].size) * chlocal.Chlocal_delta ion_nc.chlocal = PropertyDict() - ion_nc.chlocal.v = chlocal[:] * r / Bohr2Ang ** 3 + ion_nc.chlocal.v = chlocal[:] * r / Bohr2Ang**3 ion_nc.chlocal.r = r * Bohr2Ang vlocal = self._variable("reduced_vlocal") r = aranged(vlocal[:].size) * vlocal.Reduced_vlocal_delta ion_nc.vlocal = PropertyDict() - ion_nc.vlocal.v = vlocal[:] * r / Bohr2Ang ** 3 + ion_nc.vlocal.v = vlocal[:] * r / Bohr2Ang**3 ion_nc.vlocal.r = r * Bohr2Ang if "core" in self.variables: @@ -213,7 +211,7 @@ def opts(*args): core = self._variable("core") r = aranged(core[:].size) * core.Core_delta ion_nc.core = PropertyDict() - ion_nc.core.v = core[:] * r / Bohr2Ang ** 3 + ion_nc.core.v = core[:] * r / Bohr2Ang**3 ion_nc.core.r = r * Bohr2Ang d = { @@ -226,31 +224,34 @@ def opts(*args): # l-quantum number class lRange(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): - value = (value - .replace("s", 0) - .replace("p", 1) - .replace("d", 2) - .replace("f", 3) - .replace("g", 4) + value = ( + value.replace("s", 0) + .replace("p", 1) + .replace("d", 2) + .replace("f", 3) + .replace("g", 4) ) ns._l = strmap(int, value)[0] - p.add_argument("-l", - action=lRange, - help="Denote the sub-section of l-shells that are plotted: 's,f'") + + p.add_argument( + "-l", + action=lRange, + help="Denote the sub-section of l-shells that are plotted: 's,f'", + ) # n quantum number class nRange(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): ns._n = strmap(int, value)[0] - p.add_argument("-n", - action=nRange, - help="Denote the sub-section of n quantum numbers that are plotted: '2-4,6'") - class Plot(argparse.Action): + p.add_argument( + "-n", + action=nRange, + help="Denote the sub-section of n quantum numbers that are plotted: '2-4,6'", + ) + class Plot(argparse.Action): def __call__(self, parser, ns, value, option_string=None): import matplotlib.pyplot as plt @@ -273,8 +274,9 @@ def __call__(self, parser, ns, value, option_string=None): fig, axs = plt.subplots(2, 2) # Now plot different orbitals - for n, l, zeta, pol, r, orb in zip(data.n, data.l, data.zeta, - data.pol, data.r, data.orbital): + for n, l, zeta, pol, r, orb in zip( + data.n, data.l, data.zeta, data.pol, data.r, data.orbital + ): if pol == 1: pol = "P" else: @@ -287,7 +289,8 @@ def __call__(self, parser, ns, value, option_string=None): # plot projectors for n, l, e, r, proj in zip( - data.kb.n, data.kb.l, data.kb.e, data.kb.r, data.kb.proj): + data.kb.n, data.kb.l, data.kb.e, data.kb.r, data.kb.proj + ): axs[0][1].plot(r, proj, label=f"n{n}l{l} e={e:.5f}") axs[0][1].set_title("KB projectors") axs[0][1].set_xlabel("Distance [Ang]") @@ -313,8 +316,14 @@ def __call__(self, parser, ns, value, option_string=None): plt.show() else: plt.savefig(value) - p.add_argument(*opts("--plot", "-p"), action=Plot, nargs="?", metavar="FILE", - help="Plot the content basis set file, possibly saving plot to a file.") + + p.add_argument( + *opts("--plot", "-p"), + action=Plot, + nargs="?", + metavar="FILE", + help="Plot the content basis set file, possibly saving plot to a file.", + ) return p, namespace diff --git a/src/sisl/io/siesta/binaries.py b/src/sisl/io/siesta/binaries.py index 5cbb7cc6d7..d7bc1609b8 100644 --- a/src/sisl/io/siesta/binaries.py +++ b/src/sisl/io/siesta/binaries.py @@ -9,6 +9,7 @@ try: from . import _siesta + has_fortran_module = True except ImportError: has_fortran_module = False @@ -55,7 +56,7 @@ def _toF(array, dtype, scale=None): def _geometry_align(geom_b, geom_u, cls, method): - """ Routine used to align two geometries + """Routine used to align two geometries There are a few twists in this since the fdf-reads will automatically try and pass a geometry from the output files. @@ -87,6 +88,7 @@ def _geometry_align(geom_b, geom_u, cls, method): geom = geom_u is_copy = False + def get_copy(geom, is_copy): if is_copy: return geom, True @@ -94,20 +96,28 @@ def get_copy(geom, is_copy): if geom_b.na != geom.na: # we have no way of solving this issue... - raise SileError(f"{cls.__name__}.{method} could not use the passed geometry as the " - f"of atoms is not consistent, user-atoms={geom_u.na}, file-atoms={geom_b.na}.") + raise SileError( + f"{cls.__name__}.{method} could not use the passed geometry as the " + f"of atoms is not consistent, user-atoms={geom_u.na}, file-atoms={geom_b.na}." + ) # Try and figure out what to do if not np.allclose(geom_b.xyz, geom.xyz): - warn(f"{cls.__name__}.{method} has mismatched atomic coordinates, will copy geometry and use file XYZ.") + warn( + f"{cls.__name__}.{method} has mismatched atomic coordinates, will copy geometry and use file XYZ." + ) geom, is_copy = get_copy(geom, is_copy) geom.xyz[:, :] = geom_b.xyz[:, :] if not np.allclose(geom_b.lattice.cell, geom.lattice.cell): - warn(f"{cls.__name__}.{method} has non-equal lattice vectors, will copy geometry and use file lattice.") + warn( + f"{cls.__name__}.{method} has non-equal lattice vectors, will copy geometry and use file lattice." + ) geom, is_copy = get_copy(geom, is_copy) geom.lattice.cell[:, :] = geom_b.lattice.cell[:, :] if not np.array_equal(geom_b.nsc, geom.nsc): - warn(f"{cls.__name__}.{method} has non-equal number of supercells, will copy geometry and use file supercell count.") + warn( + f"{cls.__name__}.{method} has non-equal number of supercells, will copy geometry and use file supercell count." + ) geom, is_copy = get_copy(geom, is_copy) geom.set_nsc(geom_b.nsc) @@ -116,19 +126,23 @@ def get_copy(geom, is_copy): # prefer to use the user-supplied atomic species, but fill with # *random* orbitals if not np.array_equal(geom_b.atoms.orbitals, geom.atoms.orbitals): - warn(f"{cls.__name__}.{method} has non-equal number of orbitals per atom, will correct with *empty* orbitals.") + warn( + f"{cls.__name__}.{method} has non-equal number of orbitals per atom, will correct with *empty* orbitals." + ) geom, is_copy = get_copy(geom, is_copy) # Now create a new atom specie with the correct number of orbitals norbs = geom_b.atoms.orbitals[:] - atoms = Atoms([geom.atoms[i].copy(orbitals=[-1.] * norbs[i]) for i in range(geom.na)]) + atoms = Atoms( + [geom.atoms[i].copy(orbitals=[-1.0] * norbs[i]) for i in range(geom.na)] + ) geom._atoms = atoms return geom def _add_overlap(M, S, str_method): - """ Adds the overlap matrix to the matrix `M` + """Adds the overlap matrix to the matrix `M` Handles different cases of `S` """ @@ -142,21 +156,23 @@ def _add_overlap(M, S, str_method): if S.non_orthogonal: M._csr._D[:, M.S_idx] = S._csr._D[:, S.S_idx] else: - raise NotImplementedError(f"{str_method} could not paste overlap matrix into the " - "matrix due to non-conforming sparse elements.") + raise NotImplementedError( + f"{str_method} could not paste overlap matrix into the " + "matrix due to non-conforming sparse elements." + ) @set_module("sisl.io.siesta") class onlysSileSiesta(SileBinSiesta): - """ Geometry and overlap matrix """ + """Geometry and overlap matrix""" @property def version(self) -> int: - """ The version of the file """ + """The version of the file""" return _siesta.read_tshs_version(self.file) def read_lattice(self): - """ Returns a Lattice object from a TranSiesta file """ + """Returns a Lattice object from a TranSiesta file""" n_s = _siesta.read_tshs_sizes(self.file)[3] self._fortran_check("read_lattice", "could not read sizes.") arr = _siesta.read_tshs_cell(self.file, n_s) @@ -172,7 +188,7 @@ def read_lattice(self): return Lattice(cell, nsc=nsc) def read_geometry(self, geometry=None): - """ Returns Geometry object from a TranSiesta file """ + """Returns Geometry object from a TranSiesta file""" # Read supercell lattice = self.read_lattice() @@ -200,7 +216,7 @@ def read_geometry(self, geometry=None): # Create atoms atoms = [] for Z, orb in enumerate(uorb): - atoms.append(Atom(Z+1, [-1] * orb)) + atoms.append(Atom(Z + 1, [-1] * orb)) def get_atom(atoms, orbs): for atom in atoms: @@ -220,15 +236,21 @@ def get_atom(atoms, orbs): atom.append(a) else: # correct atom - atom.append(a.__class__(a.Z, [-1. for io in range(no)], mass=a.mass, tag=a.tag)) + atom.append( + a.__class__( + a.Z, [-1.0 for io in range(no)], mass=a.mass, tag=a.tag + ) + ) # Create and return geometry object return Geometry(xyz, atom, lattice=lattice) def read_overlap(self, **kwargs): - """ Returns the overlap matrix from the TranSiesta file """ + """Returns the overlap matrix from the TranSiesta file""" tshs_g = self.read_geometry() - geom = _geometry_align(tshs_g, kwargs.get("geometry", tshs_g), self.__class__, "read_overlap") + geom = _geometry_align( + tshs_g, kwargs.get("geometry", tshs_g), self.__class__, "read_overlap" + ) # read the sizes used... sizes = _siesta.read_tshs_sizes(self.file) @@ -264,7 +286,7 @@ def read_overlap(self, **kwargs): return S.transpose(sort=kwargs.get("sort", True)) def read_fermi_level(self): - r""" Query the Fermi-level contained in the file + r"""Query the Fermi-level contained in the file Returns ------- @@ -277,10 +299,10 @@ def read_fermi_level(self): @set_module("sisl.io.siesta") class tshsSileSiesta(onlysSileSiesta): - """ Geometry, Hamiltonian and overlap matrix file """ + """Geometry, Hamiltonian and overlap matrix file""" def read_hamiltonian(self, geometry=None, **kwargs): - """ Electronic structure from the siesta.TSHS file + """Electronic structure from the siesta.TSHS file The TSHS file format does *not* contain exact orbital information. When reading the Hamiltonian directly using this class one will find @@ -325,7 +347,9 @@ def read_hamiltonian(self, geometry=None, **kwargs): no = sizes[2] nnz = sizes[4] ncol, col, dH, dS = _siesta.read_tshs_hs(self.file, spin, no, nnz) - self._fortran_check("read_hamiltonian", "could not read Hamiltonian and overlap matrix.") + self._fortran_check( + "read_hamiltonian", "could not read Hamiltonian and overlap matrix." + ) # Check whether it is an orthogonal basis set orthogonal = np.abs(dS).sum() == geom.no @@ -344,7 +368,7 @@ def read_hamiltonian(self, geometry=None, **kwargs): H._csr._D = _a.emptyd([nnz, spin]) H._csr._D[:, :] = dH[:, :] * _Ry2eV else: - H._csr._D = _a.emptyd([nnz, spin+1]) + H._csr._D = _a.emptyd([nnz, spin + 1]) H._csr._D[:, :spin] = dH[:, :] * _Ry2eV H._csr._D[:, spin] = dS[:] @@ -355,25 +379,29 @@ def read_hamiltonian(self, geometry=None, **kwargs): _csr_from_sc_off(H.geometry, isc, H._csr) # Find all indices where dS == 1 (remember col is in fortran indices) - idx = col[np.isclose(dS, 1.).nonzero()[0]] + idx = col[np.isclose(dS, 1.0).nonzero()[0]] if np.any(idx > no): print(f"Number of orbitals: {no}") print(idx) - raise SileError(f"{self!s}.read_hamiltonian could not assert " - "the supercell connections in the primary unit-cell.") + raise SileError( + f"{self!s}.read_hamiltonian could not assert " + "the supercell connections in the primary unit-cell." + ) # see onlysSileSiesta.read_overlap for .transpose() # For H, DM and EDM we also need to Hermitian conjugate it. return H.transpose(spin=False, sort=kwargs.get("sort", True)) def write_hamiltonian(self, H, **kwargs): - """ Writes the Hamiltonian to a siesta.TSHS file """ + """Writes the Hamiltonian to a siesta.TSHS file""" # we sort below, so no need to do it here # see onlysSileSiesta.read_overlap for .transpose() csr = H.transpose(spin=False, sort=False)._csr if csr.nnz == 0: - raise SileError(f"{self!s}.write_hamiltonian cannot write " - "a zero element sparse matrix!") + raise SileError( + f"{self!s}.write_hamiltonian cannot write " + "a zero element sparse matrix!" + ) # Convert to siesta CSR _csr_to_siesta(H.geometry, csr) @@ -387,16 +415,18 @@ def write_hamiltonian(self, H, **kwargs): # Get H and S if H.orthogonal: h = csr._D - s = csr.diags(1., dim=1) + s = csr.diags(1.0, dim=1) # Ensure all data is correctly formatted (i.e. have the same sparsity pattern) s.align(csr) s.finalize(sort=kwargs.get("sort", True)) if s.nnz != len(h): - raise SislError("The diagonal elements of your orthogonal Hamiltonian " - "have not been defined, this is a requirement.") + raise SislError( + "The diagonal elements of your orthogonal Hamiltonian " + "have not been defined, this is a requirement." + ) s = s._D[:, 0] else: - h = csr._D[:, :H.S_idx] + h = csr._D[:, : H.S_idx] s = csr._D[:, H.S_idx] # Get shorter variants @@ -404,20 +434,31 @@ def write_hamiltonian(self, H, **kwargs): isc = _siesta.siesta_sc_off(*nsc) # see onlysSileSiesta.read_lattice for .T - _siesta.write_tshs_hs(self.file, nsc[0], nsc[1], nsc[2], - cell.T / _Bohr2Ang, xyz.T / _Bohr2Ang, H.geometry.firsto, - csr.ncol, csr.col + 1, - _toF(h, np.float64, _eV2Ry), _toF(s, np.float64), - isc) - self._fortran_check("write_hamiltonian", "could not write Hamiltonian and overlap matrix.") + _siesta.write_tshs_hs( + self.file, + nsc[0], + nsc[1], + nsc[2], + cell.T / _Bohr2Ang, + xyz.T / _Bohr2Ang, + H.geometry.firsto, + csr.ncol, + csr.col + 1, + _toF(h, np.float64, _eV2Ry), + _toF(s, np.float64), + isc, + ) + self._fortran_check( + "write_hamiltonian", "could not write Hamiltonian and overlap matrix." + ) @set_module("sisl.io.siesta") class dmSileSiesta(SileBinSiesta): - """ Density matrix file """ + """Density matrix file""" def read_density_matrix(self, **kwargs): - """ Returns the density matrix from the siesta.DM file + """Returns the density matrix from the siesta.DM file Parameters ---------- @@ -428,7 +469,9 @@ def read_density_matrix(self, **kwargs): """ # Now read the sizes used... spin, no, nsc, nnz = _siesta.read_dm_sizes(self.file) - self._fortran_check("read_density_matrix", "could not read density matrix sizes.") + self._fortran_check( + "read_density_matrix", "could not read density matrix sizes." + ) ncol, col, dDM = _siesta.read_dm(self.file, spin, no, nsc, nnz) self._fortran_check("read_density_matrix", "could not read density matrix.") @@ -447,9 +490,11 @@ def read_density_matrix(self, **kwargs): geom.set_nsc(nsc) if geom.no != no: - raise SileError(f"{self!s}.read_density_matrix could not use the " - "passed geometry as the number of atoms or orbitals is " - "inconsistent with DM file.") + raise SileError( + f"{self!s}.read_density_matrix could not use the " + "passed geometry as the number of atoms or orbitals is " + "inconsistent with DM file." + ) # Create the density matrix container DM = DensityMatrix(geom, spin, nnzpr=1, dtype=np.float64, orthogonal=False) @@ -461,10 +506,10 @@ def read_density_matrix(self, **kwargs): DM._csr.col = col.astype(np.int32, copy=False) - 1 DM._csr._nnz = len(col) - DM._csr._D = _a.emptyd([nnz, spin+1]) + DM._csr._D = _a.emptyd([nnz, spin + 1]) DM._csr._D[:, :spin] = dDM[:, :] # DM file does not contain overlap matrix... so neglect it for now. - DM._csr._D[:, spin] = 0. + DM._csr._D[:, spin] = 0.0 _mat_spin_convert(DM) @@ -475,17 +520,22 @@ def read_density_matrix(self, **kwargs): warn(f"{self!s}.read_density_matrix may result in a wrong sparse pattern!") DM = DM.transpose(spin=False, sort=kwargs.get("sort", True)) - _add_overlap(DM, kwargs.get("overlap", None), - f"{self.__class__.__name__}.read_density_matrix") + _add_overlap( + DM, + kwargs.get("overlap", None), + f"{self.__class__.__name__}.read_density_matrix", + ) return DM def write_density_matrix(self, DM, **kwargs): - """ Writes the density matrix to a siesta.DM file """ + """Writes the density matrix to a siesta.DM file""" csr = DM.transpose(spin=False, sort=False)._csr # This ensures that we don"t have any *empty* elements if csr.nnz == 0: - raise SileError(f"{self!s}.write_density_matrix cannot write " - "a zero element sparse matrix!") + raise SileError( + f"{self!s}.write_density_matrix cannot write " + "a zero element sparse matrix!" + ) _csr_to_siesta(DM.geometry, csr) # We do not really need to sort this one, but we do for consistency @@ -497,7 +547,7 @@ def write_density_matrix(self, DM, **kwargs): if DM.orthogonal: dm = csr._D else: - dm = csr._D[:, :DM.S_idx] + dm = csr._D[:, : DM.S_idx] # Ensure shapes (say if only 1 spin) dm.shape = (-1, len(DM.spin)) @@ -510,10 +560,10 @@ def write_density_matrix(self, DM, **kwargs): @set_module("sisl.io.siesta") class tsdeSileSiesta(dmSileSiesta): - """ Non-equilibrium density matrix and energy density matrix file """ + """Non-equilibrium density matrix and energy density matrix file""" def read_energy_density_matrix(self, **kwargs): - """ Returns the energy density matrix from the siesta.TSDE file + """Returns the energy density matrix from the siesta.TSDE file Parameters ---------- @@ -524,9 +574,13 @@ def read_energy_density_matrix(self, **kwargs): """ # Now read the sizes used... spin, no, nsc, nnz = _siesta.read_tsde_sizes(self.file) - self._fortran_check("read_energy_density_matrix", "could not read energy density matrix sizes.") + self._fortran_check( + "read_energy_density_matrix", "could not read energy density matrix sizes." + ) ncol, col, dEDM = _siesta.read_tsde_edm(self.file, spin, no, nsc, nnz) - self._fortran_check("read_energy_density_matrix", "could not read energy density matrix.") + self._fortran_check( + "read_energy_density_matrix", "could not read energy density matrix." + ) # Try and immediately attach a geometry geom = kwargs.get("geometry", kwargs.get("geom", None)) @@ -542,12 +596,16 @@ def read_energy_density_matrix(self, **kwargs): geom.set_nsc(nsc) if geom.no != no: - raise SileError(f"{self!s}.read_energy_density_matrix could " - "not use the passed geometry as the number of atoms or orbitals " - "is inconsistent with DM file.") + raise SileError( + f"{self!s}.read_energy_density_matrix could " + "not use the passed geometry as the number of atoms or orbitals " + "is inconsistent with DM file." + ) # Create the energy density matrix container - EDM = EnergyDensityMatrix(geom, spin, nnzpr=1, dtype=np.float64, orthogonal=False) + EDM = EnergyDensityMatrix( + geom, spin, nnzpr=1, dtype=np.float64, orthogonal=False + ) # Create the new sparse matrix EDM._csr.ncol = ncol.astype(np.int32, copy=False) @@ -556,10 +614,10 @@ def read_energy_density_matrix(self, **kwargs): EDM._csr.col = col.astype(np.int32, copy=False) - 1 EDM._csr._nnz = len(col) - EDM._csr._D = _a.emptyd([nnz, spin+1]) + EDM._csr._D = _a.emptyd([nnz, spin + 1]) EDM._csr._D[:, :spin] = dEDM[:, :] * _Ry2eV # EDM file does not contain overlap matrix... so neglect it for now. - EDM._csr._D[:, spin] = 0. + EDM._csr._D[:, spin] = 0.0 _mat_spin_convert(EDM) @@ -567,15 +625,20 @@ def read_energy_density_matrix(self, **kwargs): if nsc[0] != 0 or geom.no_s >= col.max(): _csr_from_siesta(geom, EDM._csr) else: - warn(f"{self!s}.read_energy_density_matrix may result in a wrong sparse pattern!") + warn( + f"{self!s}.read_energy_density_matrix may result in a wrong sparse pattern!" + ) EDM = EDM.transpose(spin=False, sort=kwargs.get("sort", True)) - _add_overlap(EDM, kwargs.get("overlap", None), - f"{self.__class__.__name__}.read_energy_density_matrix") + _add_overlap( + EDM, + kwargs.get("overlap", None), + f"{self.__class__.__name__}.read_energy_density_matrix", + ) return EDM def read_fermi_level(self): - r""" Query the Fermi-level contained in the file + r"""Query the Fermi-level contained in the file Returns ------- @@ -585,8 +648,8 @@ def read_fermi_level(self): self._fortran_check("read_fermi_level", "could not read fermi-level.") return Ef - def write_density_matrices(self, DM, EDM, Ef=0., **kwargs): - r""" Writes the density matrix to a siesta.DM file + def write_density_matrices(self, DM, EDM, Ef=0.0, **kwargs): + r"""Writes the density matrix to a siesta.DM file Parameters ---------- @@ -603,8 +666,10 @@ def write_density_matrices(self, DM, EDM, Ef=0., **kwargs): EDMcsr.align(DMcsr) if DMcsr.nnz == 0: - raise SileError(f"{self!s}.write_density_matrices cannot write " - "a zero element sparse matrix!") + raise SileError( + f"{self!s}.write_density_matrices cannot write " + "a zero element sparse matrix!" + ) _csr_to_siesta(DM.geometry, DMcsr) _csr_to_siesta(DM.geometry, EDMcsr) @@ -615,31 +680,42 @@ def write_density_matrices(self, DM, EDM, Ef=0., **kwargs): _mat_spin_convert(EDMcsr, EDM.spin) # Ensure everything is correct - if not (np.allclose(DMcsr.ncol, EDMcsr.ncol) and - np.allclose(DMcsr.col, EDMcsr.col)): - raise ValueError(f"{self!s}.write_density_matrices got non compatible " - "DM and EDM matrices.") + if not ( + np.allclose(DMcsr.ncol, EDMcsr.ncol) and np.allclose(DMcsr.col, EDMcsr.col) + ): + raise ValueError( + f"{self!s}.write_density_matrices got non compatible " + "DM and EDM matrices." + ) if DM.orthogonal: dm = DMcsr._D else: - dm = DMcsr._D[:, :DM.S_idx] + dm = DMcsr._D[:, : DM.S_idx] if EDM.orthogonal: edm = EDMcsr._D else: - edm = EDMcsr._D[:, :EDM.S_idx] + edm = EDMcsr._D[:, : EDM.S_idx] nsc = DM.geometry.lattice.nsc.astype(np.int32) - _siesta.write_tsde_dm_edm(self.file, nsc, DMcsr.ncol, DMcsr.col + 1, - _toF(dm, np.float64), - _toF(edm, np.float64, _eV2Ry), Ef * _eV2Ry) - self._fortran_check("write_density_matrices", "could not write DM + EDM matrices.") + _siesta.write_tsde_dm_edm( + self.file, + nsc, + DMcsr.ncol, + DMcsr.col + 1, + _toF(dm, np.float64), + _toF(edm, np.float64, _eV2Ry), + Ef * _eV2Ry, + ) + self._fortran_check( + "write_density_matrices", "could not write DM + EDM matrices." + ) @set_module("sisl.io.siesta") class hsxSileSiesta(SileBinSiesta): - """ Hamiltonian and overlap matrix file + """Hamiltonian and overlap matrix file This file does not contain all information regarding the system. @@ -657,11 +733,11 @@ class hsxSileSiesta(SileBinSiesta): @property def version(self) -> int: - """ The version of the file """ + """The version of the file""" return _siesta.read_hsx_version(self.file) def _xij2system(self, xij, geometry=None): - """ Create a new geometry with *correct* nsc and somewhat correct xyz + """Create a new geometry with *correct* nsc and somewhat correct xyz Parameters ---------- @@ -670,6 +746,7 @@ def _xij2system(self, xij, geometry=None): geometry : Geometry, optional passed geometry """ + def get_geom_handle(xij): atoms = self._read_atoms() if not atoms is None: @@ -684,14 +761,14 @@ def get_geom_handle(xij): # Parse xij to correct geometry # first figure out all zeros (i.e. self-atom-orbitals) - idx0 = np.isclose(np.fabs(xij._D).sum(axis=1), 0.).nonzero()[0] + idx0 = np.isclose(np.fabs(xij._D).sum(axis=1), 0.0).nonzero()[0] row0 = row[idx0] # convert row0 and col0 to a first attempt of "atomization" atoms = [] for r in range(N): idx0r = (row0 == r).nonzero()[0] - #row0r = row0[idx0r] + # row0r = row0[idx0r] # although xij == 0, we just do % to ensure unit-cell orbs col0r = col[idx0[idx0r]] % N if np.all(col0r >= r): @@ -703,12 +780,14 @@ def get_geom_handle(xij): # convert list of orbitals to lists atoms = [list(a) for a in atoms] if sum(map(len, atoms)) != len(xij): - raise ValueError(f"{self.__class__.__name__} could not determine correct " - "number of orbitals.") + raise ValueError( + f"{self.__class__.__name__} could not determine correct " + "number of orbitals." + ) - atms = Atoms(Atom("H", [-1. for _ in atoms[0]])) + atms = Atoms(Atom("H", [-1.0 for _ in atoms[0]])) for orbs in atoms[1:]: - atms.append(Atom("H", [-1. for _ in orbs])) + atms.append(Atom("H", [-1.0 for _ in orbs])) return Geometry(np.zeros([len(atoms), 3]), atms) geom_handle = get_geom_handle(xij) @@ -732,13 +811,16 @@ def convert_to_atom(geom, xij): del idx # Now figure out if xij is consistent - duplicates = np.logical_and(np.diff(acol) == 0, - np.diff(arow) == 0).nonzero()[0] + duplicates = np.logical_and( + np.diff(acol) == 0, np.diff(arow) == 0 + ).nonzero()[0] if duplicates.size > 0: - if not np.allclose(xij[duplicates+1] - xij[duplicates], 0.): - raise ValueError(f"{self.__class__.__name__} converting xij(orb) -> xij(atom) went wrong. " - "This may happen if your coordinates are not inside the unitcell, please pass " - "a usable geometry.") + if not np.allclose(xij[duplicates + 1] - xij[duplicates], 0.0): + raise ValueError( + f"{self.__class__.__name__} converting xij(orb) -> xij(atom) went wrong. " + "This may happen if your coordinates are not inside the unitcell, please pass " + "a usable geometry." + ) # remove duplicates to create new matrix arow = np.delete(arow, duplicates) @@ -747,8 +829,11 @@ def convert_to_atom(geom, xij): # Create a new sparse matrix # Create the new index pointer - indptr = np.insert(np.array([0, len(xij)], np.int32), 1, - (np.diff(arow) != 0).nonzero()[0] + 1) + indptr = np.insert( + np.array([0, len(xij)], np.int32), + 1, + (np.diff(arow) != 0).nonzero()[0] + 1, + ) assert len(indptr) == geom.na + 1 return SparseCSR((xij, acol, indptr), shape=(geom.na, geom.na * n_s)) @@ -780,14 +865,19 @@ def coord_from_xij(xij): # check that everything is correct if (~idx).sum() > 0: neg_neighbours = neighbours[~idx] - if not np.allclose(xyz[neg_neighbours, :], - xij[atm, neg_neighbours] - xyz_atm): - raise ValueError(f"{self.__class__.__name__} xij(orb) -> xyz did not " - f"find same coordinates for different connections") + if not np.allclose( + xyz[neg_neighbours, :], xij[atm, neg_neighbours] - xyz_atm + ): + raise ValueError( + f"{self.__class__.__name__} xij(orb) -> xyz did not " + f"find same coordinates for different connections" + ) if mark.sum() != na: - raise ValueError(f"{self.__class__.__name__} xij(orb) -> Geometry does not " - f"have a fully connected geometry. It is impossible to create relative coordinates") + raise ValueError( + f"{self.__class__.__name__} xij(orb) -> Geometry does not " + f"have a fully connected geometry. It is impossible to create relative coordinates" + ) return xyz def sc_from_xij(xij, xyz): @@ -795,7 +885,7 @@ def sc_from_xij(xij, xyz): n_s = xij.shape[1] // xij.shape[0] if n_s == 1: # easy!! - return Lattice(xyz.max(axis=0) - xyz.min(axis=0) + 10., nsc=[1] * 3) + return Lattice(xyz.max(axis=0) - xyz.min(axis=0) + 10.0, nsc=[1] * 3) sc_off = _a.zerosd([n_s, 3]) mark = _a.zerosi(n_s) @@ -810,8 +900,10 @@ def sc_from_xij(xij, xyz): idx = mark[neighbour_isc] == 0 if not np.allclose(off[~idx], sc_off[neighbour_isc[~idx]]): - raise ValueError(f"{self.__class__.__name__} xij(orb) -> xyz did not " - f"find same supercell offsets for different connections") + raise ValueError( + f"{self.__class__.__name__} xij(orb) -> xyz did not " + f"find same supercell offsets for different connections" + ) if idx.sum() == 0: continue @@ -822,8 +914,10 @@ def sc_from_xij(xij, xyz): mark[nidx] = 1 sc_off[nidx] = off[idx] elif not np.allclose(sc_off[nidx], off[idx]): - raise ValueError(f"{self.__class__.__name__} xij(orb) -> xyz did not " - f"find same supercell offsets for different connections") + raise ValueError( + f"{self.__class__.__name__} xij(orb) -> xyz did not " + f"find same supercell offsets for different connections" + ) # We know that siesta returns isc # for iz in [0, 1, 2, 3, -3, -2, -1]: # for iy in [0, 1, 2, -2, -1]: @@ -832,10 +926,10 @@ def sc_from_xij(xij, xyz): # Note the first is always [0, 0, 0] # So our best chance is to *guess* the first nsc # then reshape, then guess, then reshape, then guess :) - #sc_diff = np.diff(sc_off, axis=0) + # sc_diff = np.diff(sc_off, axis=0) def get_nsc(sc_off): - """ Determine nsc depending on the axis """ + """Determine nsc depending on the axis""" # correct the offsets ndim = sc_off.ndim @@ -844,8 +938,8 @@ def get_nsc(sc_off): # always select the 2nd one since that contains the offset # for the first isc [1, 0, 0] or [0, 1, 0] or [0, 0, 1] - sc_dir = sc_off[(1, ) + np.index_exp[0] * (ndim - 2)].reshape(1, 3) - norm2_sc_dir = (sc_dir ** 2).sum() + sc_dir = sc_off[(1,) + np.index_exp[0] * (ndim - 2)].reshape(1, 3) + norm2_sc_dir = (sc_dir**2).sum() # figure out the maximum integer part # we select 0 indices for all already determined lattice # vectors since we know the first one is [0, 0, 0] @@ -877,9 +971,7 @@ def get_nsc(sc_off): # now determine cell parameters if all(nsc > 1): - cell = _a.arrayd([sc_off[0, 0, 1], - sc_off[0, 1, 0], - sc_off[1, 0, 0]]) + cell = _a.arrayd([sc_off[0, 0, 1], sc_off[0, 1, 0], sc_off[1, 0, 0]]) else: # we will never have all(nsc == 1) since that is # taken care of at the start @@ -904,7 +996,9 @@ def get_nsc(sc_off): # figure out which Cartesian direction we are *missing* cart_dir = np.argmin(lcell) - cell[i, cart_dir] = xyz[:, cart_dir].max() - xyz[:, cart_dir].min() + 10. + cell[i, cart_dir] = ( + xyz[:, cart_dir].max() - xyz[:, cart_dir].min() + 10.0 + ) i += 1 return Lattice(cell, nsc) @@ -917,7 +1011,7 @@ def get_nsc(sc_off): geometry = Geometry(xyz, geom_handle.atoms, lattice) # Move coordinates into unit-cell - geometry.xyz[:, :] = (geometry.fxyz % 1.) @ geometry.cell + geometry.xyz[:, :] = (geometry.fxyz % 1.0) @ geometry.cell else: if geometry.n_s != xij.shape[1] // xij.shape[0]: @@ -928,11 +1022,14 @@ def get_nsc(sc_off): def conv(orbs, atm): if len(orbs) == len(atm): return atm - return atm.copy(orbitals=[-1. for _ in orbs]) + return atm.copy(orbitals=[-1.0 for _ in orbs]) + atms = Atoms(list(map(conv, geom_handle.atoms, geometry.atoms))) if len(atms) != len(geometry): - raise ValueError(f"{self.__class__.__name__} passed geometry for reading " - "sparse matrix does not contain same number of atoms!") + raise ValueError( + f"{self.__class__.__name__} passed geometry for reading " + "sparse matrix does not contain same number of atoms!" + ) geometry = geometry.copy() # TODO check that geometry and xyz are the same! geometry._atoms = atms @@ -940,17 +1037,19 @@ def conv(orbs, atm): return geometry def _read_atoms(self, **kwargs): - """ Reads basis set and geometry information from the HSX file """ + """Reads basis set and geometry information from the HSX file""" # Now read the sizes used... no, na, nspecies = _siesta.read_hsx_specie_sizes(self.file) self._fortran_check("read_geometry", "could not read specie sizes.") # Read specie information - labels, val_q, norbs, isa = _siesta.read_hsx_species(self.file, nspecies, no, na) + labels, val_q, norbs, isa = _siesta.read_hsx_species( + self.file, nspecies, no, na + ) # convert to proper string labels = labels.T.reshape(nspecies, -1) labels = labels.view(f"S{labels.shape[1]}") - labels = list(map(lambda s: b''.join(s).decode("utf-8").strip(), - labels.tolist()) + labels = list( + map(lambda s: b"".join(s).decode("utf-8").strip(), labels.tolist()) ) self._fortran_check("read_geometry", "could not read species.") # to python index @@ -990,7 +1089,7 @@ def _read_atoms(self, **kwargs): # Read in orbital information atoms = [] for ispecie in range(nspecies): - n_l_zeta = _siesta.read_hsx_specie(self.file, ispecie+1, norbs[ispecie]) + n_l_zeta = _siesta.read_hsx_specie(self.file, ispecie + 1, norbs[ispecie]) self._fortran_check("read_geometry", f"could not read specie {ispecie}.") # create orbital # since the n, l, zeta is unique per atomic orbital (before expanding to @@ -1002,7 +1101,7 @@ def _read_atoms(self, **kwargs): if old_values != (n, l, zeta): old_values = (n, l, zeta) m = -l - orbs.append(AtomicOrbital(n=n, l=l, m=m, zeta=zeta, R=-1.)) + orbs.append(AtomicOrbital(n=n, l=l, m=m, zeta=zeta, R=-1.0)) m += 1 # now create atom @@ -1014,7 +1113,7 @@ def _read_atoms(self, **kwargs): return atoms def _r_geometry_v0(self, **kwargs): - """ Read the geometry from the old file version """ + """Read the geometry from the old file version""" spin, _, no, no_s, nnz = _siesta.read_hsx_sizes(self.file) self._fortran_check("read_geometry", "could not read geometry sizes.") ncol, col, _, _, dxij = _siesta.read_hsx_hsx0(self.file, spin, no, no_s, nnz) @@ -1039,7 +1138,7 @@ def _r_geometry_v1(self, **kwargs): return Geometry(xa.T * _Bohr2Ang, atoms, lattice=lattice) def read_geometry(self, **kwargs): - """ Read the geometry from the file + """Read the geometry from the file This will always work on new files Siesta >=5, but only sometimes on older versions of the HSX file format. @@ -1048,12 +1147,14 @@ def read_geometry(self, **kwargs): return getattr(self, f"_r_geometry_v{version}")(**kwargs) def read_fermi_level(self, **kwargs): - """ Reads the fermi level in the file + """Reads the fermi level in the file Only valid for files created by Siesta >=5. """ Ef = _siesta.read_hsx_ef(self.file) - msg = self._fortran_check("read_fermi_level", "could not read Fermi-level", ret_msg=True) + msg = self._fortran_check( + "read_fermi_level", "could not read Fermi-level", ret_msg=True + ) if msg: warn(msg) return Ef * _Ry2eV @@ -1069,9 +1170,11 @@ def _r_hamiltonian_v0(self, **kwargs): self._fortran_check("read_hamiltonian", "could not read Hamiltonian.") if geom.no != no or geom.no_s != no_s: - raise SileError(f"{self!s}.read_hamiltonian could not use the " - "passed geometry as the number of atoms or orbitals is " - "inconsistent with HSX file.") + raise SileError( + f"{self!s}.read_hamiltonian could not use the " + "passed geometry as the number of atoms or orbitals is " + "inconsistent with HSX file." + ) # Create the Hamiltonian container H = Hamiltonian(geom, spin, nnzpr=1, dtype=np.float32, orthogonal=False) @@ -1083,7 +1186,7 @@ def _r_hamiltonian_v0(self, **kwargs): H._csr.col = col.astype(np.int32, copy=False) H._csr._nnz = len(col) - H._csr._D = _a.empty([nnz, spin+1], dtype=dH.dtype) + H._csr._D = _a.empty([nnz, spin + 1], dtype=dH.dtype) H._csr._D[:, :spin] = dH[:, :] * _Ry2eV H._csr._D[:, spin] = dS[:] @@ -1106,9 +1209,11 @@ def _r_hamiltonian_v1(self, **kwargs): self._fortran_check("read_hamiltonian", "could not read Hamiltonian.") if geom.no != no or geom.no_s != no_s: - raise SileError(f"{self!s}.read_hamiltonian could not use the " - "passed geometry as the number of atoms or orbitals is " - "inconsistent with HSX file.") + raise SileError( + f"{self!s}.read_hamiltonian could not use the " + "passed geometry as the number of atoms or orbitals is " + "inconsistent with HSX file." + ) # Create the Hamiltonian container H = Hamiltonian(geom, spin, nnzpr=1, dtype=np.float32, orthogonal=False) @@ -1121,7 +1226,7 @@ def _r_hamiltonian_v1(self, **kwargs): H._csr.col = col.astype(np.int32, copy=False) H._csr._nnz = len(col) - H._csr._D = _a.empty([nnz, spin+1], dtype=dH.dtype) + H._csr._D = _a.empty([nnz, spin + 1], dtype=dH.dtype) H._csr._D[:, :spin] = dH[:, :] * _Ry2eV H._csr._D[:, spin] = dS[:] @@ -1133,12 +1238,12 @@ def _r_hamiltonian_v1(self, **kwargs): return H.transpose(spin=False, sort=kwargs.get("sort", True)) def read_hamiltonian(self, **kwargs): - """ Returns the electronic structure from the siesta.TSHS file """ + """Returns the electronic structure from the siesta.TSHS file""" version = _siesta.read_hsx_version(self.file) return getattr(self, f"_r_hamiltonian_v{version}")(**kwargs) def _r_overlap_v0(self, **kwargs): - """ Returns the overlap matrix from the siesta.HSX file """ + """Returns the overlap matrix from the siesta.HSX file""" geom = self.read_geometry(**kwargs) # Now read the sizes used... @@ -1149,9 +1254,11 @@ def _r_overlap_v0(self, **kwargs): self._fortran_check("read_overlap", "could not read overlap matrix.") if geom.no != no or geom.no_s != no_s: - raise SileError(f"{self!s}.read_overlap could not use the " - "passed geometry as the number of atoms or orbitals is " - "inconsistent with HSX file.") + raise SileError( + f"{self!s}.read_overlap could not use the " + "passed geometry as the number of atoms or orbitals is " + "inconsistent with HSX file." + ) # Create the Hamiltonian container S = Overlap(geom, nnzpr=1, dtype=np.float32) @@ -1174,7 +1281,7 @@ def _r_overlap_v0(self, **kwargs): return S.transpose(sort=kwargs.get("sort", True)) def _r_overlap_v1(self, **kwargs): - """ Returns the overlap matrix from the siesta.HSX file """ + """Returns the overlap matrix from the siesta.HSX file""" geom = self.read_geometry(**kwargs) # Now read the sizes used... @@ -1185,9 +1292,11 @@ def _r_overlap_v1(self, **kwargs): self._fortran_check("read_overlap", "could not read overlap matrix.") if geom.no != no or geom.no_s != no_s: - raise SileError(f"{self!s}.read_overlap could not use the " - "passed geometry as the number of atoms or orbitals is " - "inconsistent with HSX file.") + raise SileError( + f"{self!s}.read_overlap could not use the " + "passed geometry as the number of atoms or orbitals is " + "inconsistent with HSX file." + ) # Create the Hamiltonian container S = Overlap(geom, nnzpr=1) @@ -1208,14 +1317,14 @@ def _r_overlap_v1(self, **kwargs): return S.transpose(sort=kwargs.get("sort", True)) def read_overlap(self, **kwargs): - """ Returns the electronic structure from the siesta.TSHS file """ + """Returns the electronic structure from the siesta.TSHS file""" version = _siesta.read_hsx_version(self.file) return getattr(self, f"_r_overlap_v{version}")(**kwargs) @set_module("sisl.io.siesta") class wfsxSileSiesta(SileBinSiesta): - r""" Binary WFSX file reader for Siesta + r"""Binary WFSX file reader for Siesta The WFSX file assumes that users initialize the object with a `parent` argument (or one of the other geometry related objects as @@ -1239,7 +1348,7 @@ class wfsxSileSiesta(SileBinSiesta): """ def _setup(self, *args, **kwargs): - """ Simple setup that needs to be overwritten + """Simple setup that needs to be overwritten All _read_next_* methods expect the fortran file unit to be handled and that the position in the file is correct. @@ -1266,7 +1375,9 @@ def _setup(self, *args, **kwargs): lattice = kwargs.get("lattice", kwargs.get("sc", lattice)) if lattice is None and geometry is not None: - raise ValueError(f"{self.__class__.__name__}(geometry=Geometry, lattice=None) is not an allowed argument combination.") + raise ValueError( + f"{self.__class__.__name__}(geometry=Geometry, lattice=None) is not an allowed argument combination." + ) if parent is None: parent = geometry @@ -1277,14 +1388,20 @@ def _setup(self, *args, **kwargs): self._geometry = geometry self._lattice = lattice if self._parent is None and self._geometry is None and self._lattice is None: + def conv(k): - if not np.allclose(k, 0.): - warn(f"{self.__class__.__name__} cannot convert stored k-points from 1/Ang to reduced coordinates. " - "Please ensure parent=Hamiltonian, geometry=Geometry, or lattice=Lattice to ensure reduced k.") + if not np.allclose(k, 0.0): + warn( + f"{self.__class__.__name__} cannot convert stored k-points from 1/Ang to reduced coordinates. " + "Please ensure parent=Hamiltonian, geometry=Geometry, or lattice=Lattice to ensure reduced k." + ) return k / _Bohr2Ang + else: + def conv(k): return (k @ lattice.cell.T) / (2 * np.pi * _Bohr2Ang) + self._convert_k = conv def _open_wfsx(self, mode, rewind=False): @@ -1342,7 +1459,7 @@ def _setup_parsing(self, close=True, skip_basis=True): skip_basis : bool, optional whether to also read the basis or not """ - self._open_wfsx('r') + self._open_wfsx("r") # Read the sizes relevant to the file. # We also read whether there's only gamma point information or there are multiple points self._sizes = self._read_next_sizes(skip_basis=skip_basis) @@ -1363,7 +1480,7 @@ def _setup_parsing(self, close=True, skip_basis=True): Funcs = namedtuple("WFSXReads", ["read_index", "read_next"]) self._funcs = Funcs( getattr(_siesta, f"read_wfsx_index_{func_index}"), - getattr(_siesta, f"read_wfsx_next_{func_index}") + getattr(_siesta, f"read_wfsx_next_{func_index}"), ) if close: @@ -1387,7 +1504,9 @@ def _read_next_sizes(self, skip_basis=False): """ # Check that we are in the right position in the file if self._state != -1: - raise SileError(f"We are not in a position to read the sizes. State is: {self._state}") + raise SileError( + f"We are not in a position to read the sizes. State is: {self._state}" + ) # Read the sizes that we can find in the WFSX file Sizes = namedtuple("Sizes", ["nspin", "no_u", "nk", "Gamma"]) sizes = _siesta.read_wfsx_next_sizes(self._iu, skip_basis) @@ -1408,7 +1527,9 @@ def _read_next_basis(self): """ # Check that we are in the right position in the file if self._state != 0: - raise SileError(f"We are not in a position to read the basis. State is: {self._state}") + raise SileError( + f"We are not in a position to read the basis. State is: {self._state}" + ) # Read the basis information that we can find in the WFSX file basis_info = _siesta.read_wfsx_next_basis(self._iu, self._sizes.no_u) # Inform that we are now in front of k point information @@ -1417,11 +1538,24 @@ def _read_next_basis(self): self._fortran_check("_read_basis", "could not read basis information") # Convert the information to a dict so that code is easier to follow. - basis_info = dict(zip(("atom_indices", "atom_labels", "orb_index_atom", "orb_n", "orb_symmetry"), basis_info)) + basis_info = dict( + zip( + ( + "atom_indices", + "atom_labels", + "orb_index_atom", + "orb_n", + "orb_symmetry", + ), + basis_info, + ) + ) # Sanitize the string information for char_key in ("atom_labels", "orb_symmetry"): - basis_info[char_key] = np.array(["".join(label).rstrip() for label in basis_info[char_key].astype(str)]) + basis_info[char_key] = np.array( + ["".join(label).rstrip() for label in basis_info[char_key].astype(str)] + ) # Find out the unique atom indices unique_ats = np.unique(basis_info["atom_indices"]) @@ -1433,7 +1567,10 @@ def _get_atom_object(at): orbitals = [ AtomicOrbital(f"{n}{symmetry}") - for n, symmetry in zip(basis_info["orb_n"][atom_orbs], basis_info["orb_symmetry"][atom_orbs]) + for n, symmetry in zip( + basis_info["orb_n"][atom_orbs], + basis_info["orb_symmetry"][atom_orbs], + ) ] return Atom(at_label, orbitals=orbitals) @@ -1469,20 +1606,27 @@ def _read_next_info(self, ispin, ik): self._ispin = ispin # Check that we are in a position where we will read state information if self._state != 1: - raise SileError(f"We are not in a position to read k point information. State is: {self._state}") + raise SileError( + f"We are not in a position to read k point information. State is: {self._state}" + ) # Read the state information file_ispin, file_ik, k, weight, nwf = _siesta.read_wfsx_next_info(self._iu) # Inform that we are now in front of state values self._state = 2 # Check that the read went fine - self._fortran_check("_read_next_info", f"could not read next eigenstate info [{ispin + 1}, {ik + 1}]") + self._fortran_check( + "_read_next_info", + f"could not read next eigenstate info [{ispin + 1}, {ik + 1}]", + ) # Check that the read indices match the indices that we were expecting. if file_ispin != ispin + 1 or file_ik != ik + 1: self._ik = file_ik - 1 self._ispin = file_ispin - 1 - raise SileError(f"WFSX indices do not match the expected ones. Expected: [{ispin + 1}, {ik + 1}], found [{file_ispin}, {file_ik}]") + raise SileError( + f"WFSX indices do not match the expected ones. Expected: [{ispin + 1}, {ik + 1}], found [{file_ispin}, {file_ik}]" + ) return k, weight, nwf @@ -1512,14 +1656,19 @@ def _read_next_values(self, ispin, ik, nwf): """ # Check that we are in the right position in the file if self._state != 2: - raise SileError(f"We are not in a position to read k point WFSX values. State is: {self._state}") + raise SileError( + f"We are not in a position to read k point WFSX values. State is: {self._state}" + ) # Read the state values idx, eig, state = self._funcs.read_next(self._iu, self._sizes.no_u, nwf) # Inform that we are now in front of the next state info self._state = 1 # Check that everything went fine - self._fortran_check("_read_next_values", f"could not read next eigenstate values [{ispin + 1}, {ik + 1}]") + self._fortran_check( + "_read_next_values", + f"could not read next eigenstate values [{ispin + 1}, {ik + 1}]", + ) return idx, eig, state @@ -1588,7 +1737,7 @@ def read_basis(self): return basis def yield_eigenstate(self): - r""" Iterates over the states in the WFSX file + r"""Iterates over the states in the WFSX file Yields ------ @@ -1642,7 +1791,9 @@ def read_eigenstate(self, k=(0, 0, 0), spin=0, ktol=1e-4): """ # Iterate over all eigenstates in the file for state in self.yield_eigenstate(): - if state.info.get("spin", 0) == spin and np.allclose(state.info["k"], k, atol=ktol): + if state.info.get("spin", 0) == spin and np.allclose( + state.info["k"], k, atol=ktol + ): # This is the state that the user requested return state return None @@ -1664,7 +1815,9 @@ def read_info(self): # Check if we are in the correct position in the file (we should be just after the header) if self._state != 1: - raise ValueError(f"We are not in a position to read eigenstate info in the file. State: {self._state}") + raise ValueError( + f"We are not in a position to read eigenstate info in the file. State: {self._state}" + ) if self._sizes.nspin == 2: nspin = 2 @@ -1681,21 +1834,21 @@ def read_info(self): return self._convert_k(k), kw, nwf def read_brillouinzone(self): - """ Read the brillouin zone object """ + """Read the brillouin zone object""" k, weight, _ = self.read_info() return BrillouinZone(self._parent, k=k, weight=weight) @set_module("sisl.io.siesta") class _gridSileSiesta(SileBinSiesta): - r""" Binary real-space grid file + r"""Binary real-space grid file The Siesta binary grid sile will automatically convert the units from Siesta units (Bohr, Ry) to sisl units (Ang, eV) provided the correct extension is present. """ def read_lattice(self, *args, **kwargs): - r""" Return the cell contained in the file """ + r"""Return the cell contained in the file""" cell = _siesta.read_grid_cell(self.file).T * _Bohr2Ang self._fortran_check("read_lattice", "could not read cell.") @@ -1703,7 +1856,7 @@ def read_lattice(self, *args, **kwargs): return Lattice(cell) def read_grid_size(self): - r""" Query grid size information such as the grid size and number of spin components + r"""Query grid size information such as the grid size and number of spin components Returns ------- @@ -1717,7 +1870,7 @@ def read_grid_size(self): return nspin, mesh def read_grid(self, index=0, dtype=np.float64, *args, **kwargs): - """ Read grid contained in the Grid file + """Read grid contained in the Grid file Parameters ---------- @@ -1754,7 +1907,7 @@ def read_grid(self, index=0, dtype=np.float64, *args, **kwargs): @set_module("sisl.io.siesta") class _gfSileSiesta(SileBinSiesta): - """ Surface Green function file containing, Hamiltonian, overlap matrix and self-energies + """Surface Green function file containing, Hamiltonian, overlap matrix and self-energies Do not mix read and write statements when using this code. Complete one or the other before doing the other thing. Fortran does not allow the same file opened twice, if this @@ -1782,7 +1935,7 @@ class _gfSileSiesta(SileBinSiesta): """ def _setup(self, *args, **kwargs): - """ Simple setup that needs to be overwritten """ + """Simple setup that needs to be overwritten""" super()._setup(*args, **kwargs) # The unit convention used for energy-points @@ -1837,26 +1990,30 @@ def _close_gf(self): pass def _step_counter(self, method, **kwargs): - """ Method for stepping values *must* be called before doing the actual read to check correct values """ + """Method for stepping values *must* be called before doing the actual read to check correct values""" opt = {"method": method} if kwargs.get("header", False): # The header only exists once, so check whether it is the correct place to read/write if self._state != -1 or self._is_read == 1: - raise SileError(f"{self.__class__.__name__}.{method} failed because the header has already " - "been read.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because the header has already " + "been read." + ) self._state = -1 self._ispin = 0 self._ik = 0 self._iE = 0 - #print("HEADER: ", self._state, self._ispin, self._ik, self._iE) + # print("HEADER: ", self._state, self._ispin, self._ik, self._iE) elif kwargs.get("HS", False): # Correct for the previous state and jump values if self._state == -1: # We have just read the header if self._is_read != 1: - raise SileError(f"{self.__class__.__name__}.{method} failed because the file descriptor " - "has not read the header.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because the file descriptor " + "has not read the header." + ) # Reset values as though the header has just been read self._state = 0 self._ispin = 0 @@ -1865,13 +2022,17 @@ def _step_counter(self, method, **kwargs): elif self._state == 0: if self._is_read == 1: - raise SileError(f"{self.__class__.__name__}.{method} failed because the file descriptor " - "has already read the current HS for the given k-point.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because the file descriptor " + "has already read the current HS for the given k-point." + ) elif self._state == 1: # We have just read from the last energy-point if self._iE + 1 != self._nE or self._is_read != 1: - raise SileError(f"{self.__class__.__name__}.{method} failed because the file descriptor " - "has not read all energy-points for a given k-point.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because the file descriptor " + "has not read all energy-points for a given k-point." + ) self._state = 0 self._ik += 1 if self._ik >= self._nk: @@ -1880,19 +2041,23 @@ def _step_counter(self, method, **kwargs): self._ik = 0 self._iE = 0 - #print("HS: ", self._state, self._ispin, self._ik, self._iE) + # print("HS: ", self._state, self._ispin, self._ik, self._iE) if self._ispin >= self._nspin: opt["spin"] = self._ispin + 1 opt["nspin"] = self._nspin - raise SileError(f"{self.__class__.__name__}.{method} failed because of missing information, " - "a non-existing entry has been requested! spin={spin} max_spin={nspin}.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because of missing information, " + "a non-existing entry has been requested! spin={spin} max_spin={nspin}." + ) else: # We are reading an energy-point if self._state == -1: - raise SileError(f"{self.__class__.__name__}.{method} failed because the file descriptor " - "has an unknown state.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because the file descriptor " + "has an unknown state." + ) elif self._state == 0: if self._is_read == 1: @@ -1900,8 +2065,10 @@ def _step_counter(self, method, **kwargs): self._state = 1 self._iE = 0 else: - raise SileError(f"{self.__class__.__name__}.{method} failed because the file descriptor " - "has an unknown state.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because the file descriptor " + "has an unknown state." + ) elif self._state == 1: if self._is_read == 0 and self._iE < self._nE: @@ -1910,16 +2077,20 @@ def _step_counter(self, method, **kwargs): elif self._is_read == 1 and self._iE + 1 < self._nE: self._iE += 1 else: - raise SileError(f"{self.__class__.__name__}.{method} failed because the file descriptor " - "has an unknown state.") + raise SileError( + f"{self.__class__.__name__}.{method} failed because the file descriptor " + "has an unknown state." + ) if self._iE >= self._nE: # You are trying to read beyond the entry opt["iE"] = self._iE + 1 opt["NE"] = self._nE - raise SileError(f"{self.__class__.__name__}.{method} failed because of missing information, " - f"a non-existing energy-point has been requested! E_index={self._iE+1} max_E_index={self._nE}.") - #print("SE: ", self._state, self._ispin, self._ik, self._iE) + raise SileError( + f"{self.__class__.__name__}.{method} failed because of missing information, " + f"a non-existing energy-point has been requested! E_index={self._iE+1} max_E_index={self._nE}." + ) + # print("SE: ", self._state, self._ispin, self._ik, self._iE) # Always signal (when stepping) that we have not yet read the thing if kwargs.get("read", False): @@ -1928,7 +2099,7 @@ def _step_counter(self, method, **kwargs): self._is_read = 0 def Eindex(self, E): - """ Return the closest energy index corresponding to the energy ``E`` + """Return the closest energy index corresponding to the energy ``E`` Parameters ---------- @@ -1941,15 +2112,21 @@ def Eindex(self, E): idxE = np.abs(self._E - E).argmin() ret_E = self._E[idxE] if abs(ret_E - E) > 5e-3: - warn(self.__class__.__name__ + " requesting energy " + - f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!") + warn( + self.__class__.__name__ + + " requesting energy " + + f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!" + ) elif abs(ret_E - E) > 1e-3: - info(self.__class__.__name__ + " requesting energy " + - f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!") + info( + self.__class__.__name__ + + " requesting energy " + + f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!" + ) return idxE def kindex(self, k): - """ Return the index of the k-point that is closests to the queried k-point (in reduced coordinates) + """Return the index of the k-point that is closests to the queried k-point (in reduced coordinates) Parameters ---------- @@ -1962,14 +2139,19 @@ def kindex(self, k): ik = np.sum(np.abs(self._k - _a.asarrayd(k)[None, :]), axis=1).argmin() ret_k = self._k[ik, :] if not np.allclose(ret_k, k, atol=0.0001): - warn(SileWarning(self.__class__.__name__ + " requesting k-point " + - "[{:.3f}, {:.3f}, {:.3f}]".format(*k) + - " found " + - "[{:.3f}, {:.3f}, {:.3f}]".format(*ret_k))) + warn( + SileWarning( + self.__class__.__name__ + + " requesting k-point " + + "[{:.3f}, {:.3f}, {:.3f}]".format(*k) + + " found " + + "[{:.3f}, {:.3f}, {:.3f}]".format(*ret_k) + ) + ) return ik def read_header(self): - """ Read the header of the file and open it for reading subsequently + """Read the header of the file and open it for reading subsequently NOTES: this method may change in the future @@ -1984,7 +2166,7 @@ def read_header(self): if self._fortran_is_open(): _siesta.io_m.rewind_file(self._iu) else: - self._open_gf('r') + self._open_gf("r") nspin, no_u, nkpt, NE = _siesta.read_gf_sizes(self._iu) self._fortran_check("read_header", "could not read sizes.") self._nspin = nspin @@ -1997,7 +2179,7 @@ def read_header(self): k, E = _siesta.read_gf_header(self._iu, nkpt, NE) self._fortran_check("read_header", "could not read header information.") - if self._nspin > 2: # non-colinear + if self._nspin > 2: # non-colinear self._no_u = no_u * 2 else: self._no_u = no_u @@ -2007,7 +2189,7 @@ def read_header(self): return nspin, no_u, self._k, self._E def disk_usage(self): - """ Calculate the estimated size of the resulting file + """Calculate the estimated size of the resulting file Returns ------- @@ -2025,7 +2207,7 @@ def disk_usage(self): # no_u ** 2 = matrix size # 16 = bytes in double complex # 1024 ** 3 = B -> GB - mem = (HS + SE) * self._no_u ** 2 * 16 / 1024 ** 3 + mem = (HS + SE) * self._no_u**2 * 16 / 1024**3 if not is_open: self._close_gf() @@ -2033,7 +2215,7 @@ def disk_usage(self): return mem def read_hamiltonian(self): - """ Return current Hamiltonian and overlap matrix from the GF file + """Return current Hamiltonian and overlap matrix from the GF file Returns ------- @@ -2042,12 +2224,14 @@ def read_hamiltonian(self): """ self._step_counter("read_hamiltonian", HS=True, read=True) H, S = _siesta.read_gf_hs(self._iu, self._no_u) - self._fortran_check("read_hamiltonian", "could not read Hamiltonian and overlap matrices.") + self._fortran_check( + "read_hamiltonian", "could not read Hamiltonian and overlap matrices." + ) # we don't convert to C order! return H * _Ry2eV, S def read_self_energy(self): - r""" Read the currently reached bulk self-energy + r"""Read the currently reached bulk self-energy The returned self-energy is: @@ -2065,7 +2249,7 @@ def read_self_energy(self): return SE * _Ry2eV def HkSk(self, k=(0, 0, 0), spin=0): - """ Retrieve H and S for the given k-point + """Retrieve H and S for the given k-point Parameters ---------- @@ -2081,20 +2265,32 @@ def HkSk(self, k=(0, 0, 0), spin=0): # find k-index that is requested ik = self.kindex(k) - _siesta.read_gf_find(self._iu, self._nspin, self._nk, self._nE, - self._state, self._ispin, self._ik, self._iE, self._is_read, - 0, spin, ik, 0) + _siesta.read_gf_find( + self._iu, + self._nspin, + self._nk, + self._nE, + self._state, + self._ispin, + self._ik, + self._iE, + self._is_read, + 0, + spin, + ik, + 0, + ) self._fortran_check("HkSk", "could not find Hamiltonian and overlap matrix.") self._state = 0 self._ispin = spin self._ik = ik self._iE = 0 - self._is_read = 0 # signal this is to be read + self._is_read = 0 # signal this is to be read return self.read_hamiltonian() def self_energy(self, E, k=0, spin=0): - """ Retrieve self-energy for a given energy-point and k-point + """Retrieve self-energy for a given energy-point and k-point Parameters ---------- @@ -2110,20 +2306,32 @@ def self_energy(self, E, k=0, spin=0): ik = self.kindex(k) iE = self.Eindex(E) - _siesta.read_gf_find(self._iu, self._nspin, self._nk, self._nE, - self._state, self._ispin, self._ik, self._iE, self._is_read, - 1, spin, ik, iE) + _siesta.read_gf_find( + self._iu, + self._nspin, + self._nk, + self._nE, + self._state, + self._ispin, + self._ik, + self._iE, + self._is_read, + 1, + spin, + ik, + iE, + ) self._fortran_check("self_energy", "could not find requested self-energy.") self._state = 1 self._ispin = spin self._ik = ik self._iE = iE - self._is_read = 0 # signal this is to be read + self._is_read = 0 # signal this is to be read return self.read_self_energy() - def write_header(self, bz, E, mu=0., obj=None): - """ Write to the binary file the header of the file + def write_header(self, bz, E, mu=0.0, obj=None): + """Write to the binary file the header of the file Parameters ---------- @@ -2180,15 +2388,27 @@ def write_header(self, bz, E, mu=0., obj=None): # Now write to it... self._step_counter("write_header", header=True, read=True) # see onlysSileSiesta.read_lattice for .T - _siesta.write_gf_header(self._iu, nspin, _toF(cell.T, np.float64, 1. / _Bohr2Ang), - na_u, no_u, no_u, _toF(xa.T, np.float64, 1. / _Bohr2Ang), - lasto, bloch, 0, mu * _eV2Ry, _toF(k.T, np.float64), - w, self._E / self._E_Ry2eV, - **sizes) + _siesta.write_gf_header( + self._iu, + nspin, + _toF(cell.T, np.float64, 1.0 / _Bohr2Ang), + na_u, + no_u, + no_u, + _toF(xa.T, np.float64, 1.0 / _Bohr2Ang), + lasto, + bloch, + 0, + mu * _eV2Ry, + _toF(k.T, np.float64), + w, + self._E / self._E_Ry2eV, + **sizes, + ) self._fortran_check("write_header", "could not write header information.") def write_hamiltonian(self, H, S=None): - """ Write the current energy, k-point and H and S to the file + """Write the current energy, k-point and H and S to the file Parameters ---------- @@ -2202,13 +2422,20 @@ def write_hamiltonian(self, H, S=None): if S is None: S = np.eye(no, dtype=np.complex128, order="F") self._step_counter("write_hamiltonian", HS=True, read=True) - _siesta.write_gf_hs(self._iu, self._ik, self._E[self._iE] / self._E_Ry2eV, - _toF(H, np.complex128, _eV2Ry), - _toF(S, np.complex128), no_u=no) - self._fortran_check("write_hamiltonian", "could not write Hamiltonian and overlap matrices.") + _siesta.write_gf_hs( + self._iu, + self._ik, + self._E[self._iE] / self._E_Ry2eV, + _toF(H, np.complex128, _eV2Ry), + _toF(S, np.complex128), + no_u=no, + ) + self._fortran_check( + "write_hamiltonian", "could not write Hamiltonian and overlap matrices." + ) def write_self_energy(self, SE): - r""" Write the current self energy, k-point and H and S to the file + r"""Write the current self energy, k-point and H and S to the file The self-energy must correspond to the *bulk* self-energy @@ -2222,15 +2449,21 @@ def write_self_energy(self, SE): """ no = len(SE) self._step_counter("write_self_energy", read=True) - _siesta.write_gf_se(self._iu, self._ik, self._iE, self._E[self._iE] / self._E_Ry2eV, - _toF(SE, np.complex128, _eV2Ry), no_u=no) + _siesta.write_gf_se( + self._iu, + self._ik, + self._iE, + self._E[self._iE] / self._E_Ry2eV, + _toF(SE, np.complex128, _eV2Ry), + no_u=no, + ) self._fortran_check("write_self_energy", "could not write self-energy.") def __len__(self): return self._nE * self._nk * self._nspin def __iter__(self): - """ Iterate through the energies and k-points that this GF file is associated with + """Iterate through the energies and k-points that this GF file is associated with Yields ------ @@ -2266,11 +2499,12 @@ def _type(name, obj, dic=None): dic["__doc__"] = obj.__doc__.replace(obj.__name__, name) except Exception: pass - return type(name, (obj, ), dic) + return type(name, (obj,), dic) + # Faster than class ... \ pass tsgfSileSiesta = _type("tsgfSileSiesta", _gfSileSiesta) -gridSileSiesta = _type("gridSileSiesta", _gridSileSiesta, {"grid_unit": 1.}) +gridSileSiesta = _type("gridSileSiesta", _gridSileSiesta, {"grid_unit": 1.0}) if has_fortran_module: add_sile("TSHS", tshsSileSiesta) @@ -2281,20 +2515,61 @@ def _type(name, obj, dic=None): add_sile("TSGF", tsgfSileSiesta) add_sile("WFSX", wfsxSileSiesta) # These have unit-conversions - add_sile("RHO", _type("rhoSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("LDOS", _type("ldosSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("RHOINIT", _type("rhoinitSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("RHOXC", _type("rhoxcSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("DRHO", _type("drhoSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("BADER", _type("baderSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("IOCH", _type("iorhoSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("TOCH", _type("totalrhoSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) + add_sile( + "RHO", + _type("rhoSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3}), + ) + add_sile( + "LDOS", + _type("ldosSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3}), + ) + add_sile( + "RHOINIT", + _type( + "rhoinitSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3} + ), + ) + add_sile( + "RHOXC", + _type("rhoxcSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3}), + ) + add_sile( + "DRHO", + _type("drhoSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3}), + ) + add_sile( + "BADER", + _type("baderSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3}), + ) + add_sile( + "IOCH", + _type("iorhoSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3}), + ) + add_sile( + "TOCH", + _type( + "totalrhoSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3} + ), + ) # The following two files *require* that # STM.DensityUnits Ele/bohr**3 # which I can't check! # They are however the default - add_sile("STS", _type("stsSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) - add_sile("STM.LDOS", _type("stmldosSileSiesta", _gridSileSiesta, {"grid_unit": 1./_Bohr2Ang ** 3})) + add_sile( + "STS", + _type("stsSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3}), + ) + add_sile( + "STM.LDOS", + _type( + "stmldosSileSiesta", _gridSileSiesta, {"grid_unit": 1.0 / _Bohr2Ang**3} + ), + ) add_sile("VH", _type("hartreeSileSiesta", _gridSileSiesta, {"grid_unit": _Ry2eV})) - add_sile("VNA", _type("neutralatomhartreeSileSiesta", _gridSileSiesta, {"grid_unit": _Ry2eV})) - add_sile("VT", _type("totalhartreeSileSiesta", _gridSileSiesta, {"grid_unit": _Ry2eV})) + add_sile( + "VNA", + _type("neutralatomhartreeSileSiesta", _gridSileSiesta, {"grid_unit": _Ry2eV}), + ) + add_sile( + "VT", _type("totalhartreeSileSiesta", _gridSileSiesta, {"grid_unit": _Ry2eV}) + ) diff --git a/src/sisl/io/siesta/eig.py b/src/sisl/io/siesta/eig.py index 1e70b5bdb7..370258fa18 100644 --- a/src/sisl/io/siesta/eig.py +++ b/src/sisl/io/siesta/eig.py @@ -18,7 +18,7 @@ @set_module("sisl.io.siesta") class eigSileSiesta(SileSiesta): - """ Eigenvalues as calculated in the SCF loop, easy plots using `sdata` + """Eigenvalues as calculated in the SCF loop, easy plots using `sdata` The .EIG file from Siesta contains the eigenvalues for k-points used during the SCF. Using the command-line utility `sdata` one may plot the eigenvalue spectrum to visualize the @@ -75,7 +75,7 @@ class eigSileSiesta(SileSiesta): @sile_fh_open(True) def read_fermi_level(self): - r""" Query the Fermi-level contained in the file + r"""Query the Fermi-level contained in the file Returns ------- @@ -85,7 +85,7 @@ def read_fermi_level(self): @sile_fh_open() def read_data(self): - r""" Read eigenvalues, as calculated and written by Siesta + r"""Read eigenvalues, as calculated and written by Siesta Returns ------- @@ -106,6 +106,7 @@ def read_data(self): eigs = np.empty([ns, nk, nb], np.float64) readline = self.readline + def iterE(size): ne = 0 out = readline().split()[1:] @@ -125,9 +126,9 @@ def iterE(size): @default_ArgumentParser(description="Manipulate Siesta EIG file.") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ - #limit_args = kwargs.get("limit_arguments", True) - #short = kwargs.get("short", False) + """Returns the arguments that is available for this Sile""" + # limit_args = kwargs.get("limit_arguments", True) + # short = kwargs.get("short", False) # We limit the import to occur here import argparse @@ -147,41 +148,59 @@ def ArgumentParser(self, p=None, *args, **kwargs): # default T = 300 k units("K", "eV") * 300, # distribution type - "gaussian" + "gaussian", ], } try: - d["_weights"] = kpSileSiesta(str(self.file).replace("EIG", "KP")).read_data()[1] + d["_weights"] = kpSileSiesta( + str(self.file).replace("EIG", "KP") + ).read_data()[1] except Exception: d["_weights"] = None namespace = default_namespace(**d) # Energy grabs class ERange(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): ns._Emap = strmap(float, value)[0] - p.add_argument("--energy", "-E", - action=ERange, - help="Denote the sub-section of energies that are plotted: '-1:0,1:2' [eV]") + + p.add_argument( + "--energy", + "-E", + action=ERange, + help="Denote the sub-section of energies that are plotted: '-1:0,1:2' [eV]", + ) # k-point weights class KP(argparse.Action): def __call__(self, parser, ns, value, option_string=None): ns._weights = kpSileSiesta(value[0]).read_data()[1] - p.add_argument("--kp-file", "-kp", nargs=1, metavar="FILE", action=KP, - help="The k-point file from which to read the band-weights (only applicable to --dos option)") + + p.add_argument( + "--kp-file", + "-kp", + nargs=1, + metavar="FILE", + action=KP, + help="The k-point file from which to read the band-weights (only applicable to --dos option)", + ) # Energy grabs class DOS(argparse.Action): - def __call__(dos_self, parser, ns, values, option_string=None): # pylint: disable=E0213 + def __call__( + dos_self, parser, ns, values, option_string=None + ): # pylint: disable=E0213 if getattr(ns, "_weights", None) is None: if ns._eigs.shape[1] > 1: - raise ValueError("Can not calculate DOS when k-point weights are unknown, please pass -kp before this command") + raise ValueError( + "Can not calculate DOS when k-point weights are unknown, please pass -kp before this command" + ) if len(ns._weights) != ns._eigs.shape[1]: - raise SileError(f"{self!s} --dos the number of k-points for the eigenvalues and k-point weights " - "are different, please use -kp before --dos.") + raise SileError( + f"{self!s} --dos the number of k-points for the eigenvalues and k-point weights " + "are different, please use -kp before --dos." + ) # Specify default settings dE = ns._dos_args[0] @@ -199,7 +218,9 @@ def __call__(dos_self, parser, ns, values, option_string=None): # pylint: disabl elif i == 2: distribution = value else: - raise ValueError(f"Too many values passed? Unknown value {value}?") + raise ValueError( + f"Too many values passed? Unknown value {value}?" + ) elif n_eq == len(values): for key, val in map(lambda x: x.split("="), values): @@ -211,10 +232,14 @@ def __call__(dos_self, parser, ns, values, option_string=None): # pylint: disabl elif key.lower().startswith("dist"): distribution = val else: - raise ValueError(f"Unknown key: {key}, should be one of [dE, kT, dist]") + raise ValueError( + f"Unknown key: {key}, should be one of [dE, kT, dist]" + ) else: - raise ValueError("Mixing position arguments and keyword arguments is not allowed, either key=val or val, only") + raise ValueError( + "Mixing position arguments and keyword arguments is not allowed, either key=val or val, only" + ) try: dE = units(dE, "eV") @@ -272,16 +297,22 @@ def calc_dos(E, eig, w): else: ns._data_header.append(f"DOS T={str_T}") ns._data.append(calc_dos(E, ns._eigs[0, :, :], ns._weights)) - p.add_argument("--dos", action=DOS, nargs="*", metavar="dE, kT, DIST", - help="Calculate (and internally store) the density of states from the .EIG file, " - "dE = energy separation (5 meV), kT = smearing (300 K), DIST = distribution function (Gaussian). " - "The arguments will be the new defaults for later --dos calls.") - class Plot(argparse.Action): + p.add_argument( + "--dos", + action=DOS, + nargs="*", + metavar="dE, kT, DIST", + help="Calculate (and internally store) the density of states from the .EIG file, " + "dE = energy separation (5 meV), kT = smearing (300 K), DIST = distribution function (Gaussian). " + "The arguments will be the new defaults for later --dos calls.", + ) + class Plot(argparse.Action): def _plot_data(self, parser, ns, value, option_string=None): - """ Plot data as contained in ns._data """ + """Plot data as contained in ns._data""" from matplotlib import pyplot as plt + plt.figure() E = None @@ -312,12 +343,12 @@ def _plot_eig(self, parser, ns, value, option_string=None): import matplotlib.pyplot as plt E = ns._eigs - #Emin = np.min(E) - #Emax = np.max(E) - #n = E.shape[1] + # Emin = np.min(E) + # Emax = np.max(E) + # n = E.shape[1] # We need to setup a relatively good size of the scatter # plots - s = 10 #20. / max(Emax - Emin, n) + s = 10 # 20. / max(Emax - Emin, n) def myplot(ax, title, y, E, s): ax.set_title(title) @@ -353,26 +384,41 @@ def __call__(self, parser, ns, value, option_string=None): else: self._plot_data(parser, ns, value, option_string) - p.add_argument("--plot", "-p", action=Plot, nargs="?", metavar="FILE", - help="Plot the currently collected information (at its current invocation), or the eigenspectrum") + p.add_argument( + "--plot", + "-p", + action=Plot, + nargs="?", + metavar="FILE", + help="Plot the currently collected information (at its current invocation), or the eigenspectrum", + ) class Out(argparse.Action): def __call__(self, parser, ns, value, option_string=None): - out = value[0] if len(ns._data) == 0: # do nothing if data has not been collected - raise ValueError("No data has been collected in the arguments, nothing will be written, have you forgotten arguments?") + raise ValueError( + "No data has been collected in the arguments, nothing will be written, have you forgotten arguments?" + ) if sum((h == "Energy" for h in ns._data_header)) > 1: - raise ValueError("There are multiple non-commensurate energy-grids, saving data requires a single energy-grid.") + raise ValueError( + "There are multiple non-commensurate energy-grids, saving data requires a single energy-grid." + ) from sisl.io import tableSile - tableSile(out, mode="w").write(*ns._data, - header=ns._data_header) - p.add_argument("--out", "-o", nargs=1, action=Out, - help="Store currently collected information (at its current invocation) to the out file.") + + tableSile(out, mode="w").write(*ns._data, header=ns._data_header) + + p.add_argument( + "--out", + "-o", + nargs=1, + action=Out, + help="Store currently collected information (at its current invocation) to the out file.", + ) return p, namespace diff --git a/src/sisl/io/siesta/fa.py b/src/sisl/io/siesta/fa.py index 8b6913fbf4..eb3fb879e1 100644 --- a/src/sisl/io/siesta/fa.py +++ b/src/sisl/io/siesta/fa.py @@ -8,16 +8,16 @@ from ..sile import add_sile, sile_fh_open, sile_raise_write from .sile import SileSiesta -__all__ = ['faSileSiesta'] +__all__ = ["faSileSiesta"] @set_module("sisl.io.siesta") class faSileSiesta(SileSiesta): - """ Forces file """ + """Forces file""" @sile_fh_open() def read_force(self): - """ Reads the forces from the file """ + """Reads the forces from the file""" na = int(self.readline()) f = np.empty([na, 3], np.float64) @@ -28,8 +28,8 @@ def read_force(self): return f @sile_fh_open() - def write_force(self, f, fmt='.9e'): - """ Write forces to file + def write_force(self, f, fmt=".9e"): + """Write forces to file Parameters ---------- @@ -38,8 +38,8 @@ def write_force(self, f, fmt='.9e'): """ sile_raise_write(self) na = len(f) - self._write(f'{na}\n') - _fmt = ('{:d}' + (' {:' + fmt + '}') * 3) + '\n' + self._write(f"{na}\n") + _fmt = ("{:d}" + (" {:" + fmt + "}") * 3) + "\n" for ia in range(na): self._write(_fmt.format(ia + 1, *f[ia, :])) @@ -49,7 +49,7 @@ def write_force(self, f, fmt='.9e'): write_data = write_force -add_sile('FA', faSileSiesta, gzip=True) -add_sile('FAC', faSileSiesta, gzip=True) -add_sile('TSFA', faSileSiesta, gzip=True) -add_sile('TSFAC', faSileSiesta, gzip=True) +add_sile("FA", faSileSiesta, gzip=True) +add_sile("FAC", faSileSiesta, gzip=True) +add_sile("TSFA", faSileSiesta, gzip=True) +add_sile("TSFAC", faSileSiesta, gzip=True) diff --git a/src/sisl/io/siesta/fc.py b/src/sisl/io/siesta/fc.py index 9a6be22bf8..0ab07cdfa6 100644 --- a/src/sisl/io/siesta/fc.py +++ b/src/sisl/io/siesta/fc.py @@ -10,16 +10,16 @@ from ..sile import add_sile, sile_fh_open from .sile import SileSiesta -__all__ = ['fcSileSiesta'] +__all__ = ["fcSileSiesta"] @set_module("sisl.io.siesta") class fcSileSiesta(SileSiesta): - """ Force constant file """ + """Force constant file""" @sile_fh_open() def read_force(self, displacement=None, na=None): - """ Reads all displacement forces by multiplying with the displacement value + """Reads all displacement forces by multiplying with the displacement value Since the force constant file does not contain the non-displaced configuration this will only return forces on the displaced configurations minus the forces from @@ -55,8 +55,10 @@ def read_force(self, displacement=None, na=None): try: displacement = float(line[-1]) except Exception: - warn(f"{self.__class__.__name__}.read_force assumes displacement=0.04 Bohr!") - displacement = 0.04 * unit_convert('Bohr', 'Ang') + warn( + f"{self.__class__.__name__}.read_force assumes displacement=0.04 Bohr!" + ) + displacement = 0.04 * unit_convert("Bohr", "Ang") # Since the displacements changes sign (starting with a negative sign) # we can convert using this scheme @@ -66,7 +68,7 @@ def read_force(self, displacement=None, na=None): @sile_fh_open() def read_force_constant(self, na=None): - """ Reads the force-constant stored in the FC file + """Reads the force-constant stored in the FC file Parameters ---------- @@ -91,7 +93,7 @@ def read_force_constant(self, na=None): fc = list() while True: line = self.readline() - if line == '': + if line == "": # empty line or nothing break fc.append(list(map(float, line.split()))) @@ -109,5 +111,5 @@ def read_force_constant(self, na=None): return fc -add_sile('FC', fcSileSiesta, gzip=True) -add_sile('FCC', fcSileSiesta, gzip=True) +add_sile("FC", fcSileSiesta, gzip=True) +add_sile("FCC", fcSileSiesta, gzip=True) diff --git a/src/sisl/io/siesta/fdf.py b/src/sisl/io/siesta/fdf.py index ac8cc447fe..45a51fbd4c 100644 --- a/src/sisl/io/siesta/fdf.py +++ b/src/sisl/io/siesta/fdf.py @@ -56,7 +56,7 @@ __all__ = ["fdfSileSiesta"] -_LOGICAL_TRUE = [".true.", "true", "yes", "y", "t"] +_LOGICAL_TRUE = [".true.", "true", "yes", "y", "t"] _LOGICAL_FALSE = [".false.", "false", "no", "n", "f"] _LOGICAL = _LOGICAL_FALSE + _LOGICAL_TRUE @@ -64,7 +64,7 @@ def _parse_output_order(order, output, order_True, order_False): - """ Parses the correct order of files by examining the kwargs["order"]. + """Parses the correct order of files by examining the kwargs["order"]. If it is not present, it will use `order_{output}` to retrieve the default list. @@ -88,7 +88,7 @@ def _parse_output_order(order, output, order_True, order_False): positive_list = False for el in order: - if el.startswith('^'): + if el.startswith("^"): rem.append(el) rem.append(el[1:]) else: @@ -130,7 +130,7 @@ def _track_file(method, f, msg=None): @set_module("sisl.io.siesta") class fdfSileSiesta(SileSiesta): - """ FDF-input file + """FDF-input file By supplying base you can reference files in other directories. By default the ``base`` is the directory given in the file name. @@ -151,9 +151,9 @@ class fdfSileSiesta(SileSiesta): """ def _setup(self, *args, **kwargs): - """ Setup the `fdfSileSiesta` after initialization """ + """Setup the `fdfSileSiesta` after initialization""" super()._setup(*args, **kwargs) - self._comment = ['#', '!', ';'] + self._comment = ["#", "!", ";"] # List of parent file-handles used while reading # This is because fdf enables inclusion of other files @@ -170,7 +170,9 @@ def _pushfile(self, f): self._parent_fh.append(self.fh) self.fh = gzip.open(self.dir_file(f"{f}.gz"), mode="rt") else: - warn(f"{self!s} is trying to include file: {f} but the file seems not to exist? Will disregard file!") + warn( + f"{self!s} is trying to include file: {f} but the file seems not to exist? Will disregard file!" + ) def _popfile(self): if len(self._parent_fh) > 0: @@ -180,7 +182,7 @@ def _popfile(self): return False def _seek(self): - """ Closes all files, and starts over from beginning """ + """Closes all files, and starts over from beginning""" try: while self._popfile(): pass @@ -190,30 +192,32 @@ def _seek(self): @sile_fh_open() def includes(self): - """ Return a list of all files that are *included* or otherwise necessary for reading the fdf file """ + """Return a list of all files that are *included* or otherwise necessary for reading the fdf file""" self._seek() + # In FDF files, %include marks files that progress # down in a tree structure def add(f): f = self.dir_file(f) if f not in includes: includes.append(f) + # List of includes includes = [] l = self.readline() - while l != '': + while l != "": ls = l.split() if ls: if "%include" == ls[0].lower(): add(ls[1]) self._pushfile(ls[1]) - elif '<' in ls: + elif "<" in ls: # TODO, in principle the < could contain # include if this line is not a %block. - add(ls[ls.index('<')+1]) + add(ls[ls.index("<") + 1]) l = self.readline() - while l == '': + while l == "": # last line of file if self._popfile(): l = self.readline() @@ -224,7 +228,7 @@ def add(f): @sile_fh_open() def _read_label(self, label): - """ Try and read the first occurence of a key + """Try and read the first occurence of a key This will take care of blocks, labels and piped in labels @@ -234,8 +238,10 @@ def _read_label(self, label): label to find in the fdf file """ self._seek() + def tolabel(label): - return label.lower().replace('_', '').replace('-', '').replace('.', '') + return label.lower().replace("_", "").replace("-", "").replace(".", "") + labell = tolabel(label) def valid_line(line): @@ -254,8 +260,8 @@ def process_line(line): lsl = list(map(tolabel, ls)) # Check if there is a pipe in the line - if '<' in lsl: - idx = lsl.index('<') + if "<" in lsl: + idx = lsl.index("<") # Now there are two cases # 1. It is a block, in which case @@ -264,14 +270,16 @@ def process_line(line): if lsl[0] == "%block" and lsl[1] == labell: # Correct line found # Read the file content, removing any empty and/or comment lines - lines = self.dir_file(ls[3]).open('r').readlines() + lines = self.dir_file(ls[3]).open("r").readlines() return [l.strip() for l in lines if valid_line(l)] # 2. There are labels that should be read from a subsequent file # Label1 Label2 < other.fdf if labell in lsl[:idx]: # Valid line, read key from other.fdf - return fdfSileSiesta(self.dir_file(ls[idx+1]), base=self._directory)._read_label(label) + return fdfSileSiesta( + self.dir_file(ls[idx + 1]), base=self._directory + )._read_label(label) # It is not in this line, either key is # on the RHS of <, or the key could be "block". Say. @@ -280,7 +288,7 @@ def process_line(line): # The last case is if the label is the first word on the line # In that case we have found what we are looking for if lsl[0] == labell: - return (' '.join(ls[1:])).strip() + return (" ".join(ls[1:])).strip() elif lsl[0] == "%block": if lsl[1] == labell: @@ -296,19 +304,18 @@ def process_line(line): return lines elif lsl[0] == "%include": - # We have to open a new file self._pushfile(ls[1]) return None # Perform actual reading of line - l = self.readline().split('#')[0] + l = self.readline().split("#")[0] if len(l) == 0: return None l = process_line(l) while l is None: - l = self.readline().split('#')[0] + l = self.readline().split("#")[0] if len(l) == 0: if not self._popfile(): return None @@ -318,7 +325,7 @@ def process_line(line): @classmethod def _type(cls, value): - """ Determine the type by the value + """Determine the type by the value Parameters ---------- @@ -330,11 +337,11 @@ def _type(cls, value): if isinstance(value, list): # A block, %block ... - return 'B' + return "B" if isinstance(value, np.ndarray): # A list, Label [...] - return 'a' + return "a" # Grab the entire line (beside the key) values = value.split() @@ -342,14 +349,14 @@ def _type(cls, value): fdf = values[0].lower() if fdf in _LOGICAL: # logical - return 'b' + return "b" try: float(fdf) - if '.' in fdf: + if "." in fdf: # a real number (otherwise an integer) - return 'r' - return 'i' + return "r" + return "i" except Exception: pass # fall-back to name with everything @@ -358,15 +365,15 @@ def _type(cls, value): # possibly a physical value try: float(values[0]) - return 'p' + return "p" except Exception: pass - return 'n' + return "n" @sile_fh_open() def type(self, label): - """ Return the type of the fdf-keyword + """Return the type of the fdf-keyword Parameters ---------- @@ -378,7 +385,7 @@ def type(self, label): @sile_fh_open() def get(self, label, default=None, unit=None, with_unit=False): - """ Retrieve fdf-keyword from the file + """Retrieve fdf-keyword from the file Parameters ---------- @@ -430,19 +437,19 @@ def get(self, label, default=None, unit=None, with_unit=False): # We will only do something if it is a real, int, or physical. # Else we simply return, as-is - if t == 'r': + if t == "r": if default is None: return float(value) t = type(default) return t(value) - elif t == 'i': + elif t == "i": if default is None: return int(value) t = type(default) return t(value) - elif t == 'p': + elif t == "p": value = value.split() if with_unit: # Simply return, as is. Let the user do whatever. @@ -451,18 +458,20 @@ def get(self, label, default=None, unit=None, with_unit=False): default = unit_default(unit_group(value[1])) else: if unit_group(value[1]) != unit_group(unit): - raise ValueError(f"Requested unit for {label} is not the same type. " - "Found/Requested {value[1]}/{unit}") + raise ValueError( + f"Requested unit for {label} is not the same type. " + "Found/Requested {value[1]}/{unit}" + ) default = unit return float(value[0]) * unit_convert(value[1], default) - elif t == 'b': + elif t == "b": return value.lower() in _LOGICAL_TRUE return value def set(self, key, value, keep=True): - """ Add the key and value to the FDF file + """Add the key and value to the FDF file Parameters ---------- @@ -483,7 +492,7 @@ def set(self, key, value, keep=True): # 1. find the old value, and thus the file in which it is found if isfile(top_file): - same_fdf = self.__class__(top_file, 'r') + same_fdf = self.__class__(top_file, "r") try: same_fdf.get(key) # Get the file of the containing data @@ -496,24 +505,28 @@ def set(self, key, value, keep=True): self._open() # Now we should re-read and edit the file - lines = open(top_file, 'r').readlines() + lines = open(top_file, "r").readlines() def write(fh, value): if value is None: return fh.write(self.print(key, value)) - if isinstance(value, str) and '\n' not in value: - fh.write('\n') + if isinstance(value, str) and "\n" not in value: + fh.write("\n") # Now loop, write and edit do_write = True lkey = key.lower() - with open(top_file, 'w') as fh: + with open(top_file, "w") as fh: for line in lines: if self.line_has_key(line, lkey, case=False) and do_write: write(fh, value) if keep: - fh.write("# Old value ({})\n".format(datetime.today().strftime("%Y-%m-%d %H:%M"))) + fh.write( + "# Old value ({})\n".format( + datetime.today().strftime("%Y-%m-%d %H:%M") + ) + ) fh.write(f"{line}") do_write = False else: @@ -526,7 +539,7 @@ def write(fh, value): @staticmethod def print(key, value): - """ Return a string which is pretty-printing the key+value """ + """Return a string which is pretty-printing the key+value""" if isinstance(value, list): s = f"%block {key}" # if the value has any new-values @@ -551,7 +564,7 @@ def print(key, value): @sile_fh_open() def write_lattice(self, sc, fmt=".8f", *args, **kwargs): - """ Writes the supercell + """Writes the supercell Parameters ---------- @@ -567,7 +580,7 @@ def write_lattice(self, sc, fmt=".8f", *args, **kwargs): fmt_str = " {{0:{0}}} {{1:{0}}} {{2:{0}}}\n".format(fmt) unit = kwargs.get("unit", "Ang").capitalize() - conv = 1. + conv = 1.0 if unit in ("Ang", "Bohr"): conv = unit_convert("Ang", unit) else: @@ -583,7 +596,7 @@ def write_lattice(self, sc, fmt=".8f", *args, **kwargs): @sile_fh_open() def write_geometry(self, geometry, fmt=".8f", *args, **kwargs): - """ Writes the geometry + """Writes the geometry Parameters ---------- @@ -616,12 +629,14 @@ def write_geometry(self, geometry, fmt=".8f", *args, **kwargs): xyz = geometry.fxyz else: xyz = geometry.xyz * conv - if fmt[0] == '.': + if fmt[0] == ".": # Correct for a "same" length of all coordinates c_max = len(str((f"{{:{fmt}}}").format(xyz.max()))) c_min = len(str((f"{{:{fmt}}}").format(xyz.min()))) fmt = str(max(c_min, c_max)) + fmt - fmt_str = " {{3:{0}}} {{4:{0}}} {{5:{0}}} {{0}} # {{1:{1}d}}: {{2}}\n".format(fmt, len(str(len(geometry)))) + fmt_str = " {{3:{0}}} {{4:{0}}} {{5:{0}}} {{0}} # {{1:{1}d}}: {{2}}\n".format( + fmt, len(str(len(geometry))) + ) for ia, a, isp in geometry.iter_species(): self._write(fmt_str.format(isp + 1, ia + 1, a.tag, *xyz[ia, :])) @@ -639,6 +654,7 @@ def write_geometry(self, geometry, fmt=".8f", *args, **kwargs): self._write("%endblock ChemicalSpeciesLabel\n") _write_block = True + def write_block(atoms, append, write_block): if write_block: self._write("\n# Constraints\n%block Geometry.Constraints\n") @@ -652,7 +668,9 @@ def write_block(atoms, append, write_block): if n in geometry.names: idx = list2str(geometry.names[n] + 1).replace("-", " -- ") if len(idx) > 200: - info(f"{self!s}.write_geometry will not write the constraints for {n} (too long line).") + info( + f"{self!s}.write_geometry will not write the constraints for {n} (too long line)." + ) else: _write_block = write_block(idx, append, _write_block) @@ -661,7 +679,7 @@ def write_block(atoms, append, write_block): @sile_fh_open() def write_brillouinzone(self, bz): - r""" Writes Brillouinzone information to the fdf input file + r"""Writes Brillouinzone information to the fdf input file The `bz` object will be written as options in the input file. The class of `bz` decides which options gets written. For instance @@ -679,33 +697,41 @@ def write_brillouinzone(self, bz): sile_raise_write(self) if not isinstance(bz, BandStructure): - raise NotImplementedError(f"{self.__class__.__name__}.write_brillouinzone only implements BandStructure object writing.") + raise NotImplementedError( + f"{self.__class__.__name__}.write_brillouinzone only implements BandStructure object writing." + ) self._write("BandLinesScale ReciprocalLatticeVectors\n%block BandLines\n") ip = 1 for divs, point, name in zip(bz.divisions, bz.points, bz.names): - self._write(f" {ip} {point[0]:.5f} {point[1]:.5f} {point[2]:.5f} {name}\n") + self._write( + f" {ip} {point[0]:.5f} {point[1]:.5f} {point[2]:.5f} {name}\n" + ) ip = divs point, name = bz.points[-1], bz.names[-1] self._write(f" {ip} {point[0]:.5f} {point[1]:.5f} {point[2]:.5f} {name}\n") self._write("%endblock BandLines\n") def read_lattice_nsc(self, *args, **kwargs): - """ Read supercell size using any method available + """Read supercell size using any method available Raises ------ SislWarning if none of the files can be read """ - order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "ORB_INDX"], []) + order = _parse_output_order( + kwargs.pop("order", None), True, ["nc", "ORB_INDX"], [] + ) for f in order: v = getattr(self, f"_r_lattice_nsc_{f.lower()}")(*args, **kwargs) if v is not None: _track(self.read_lattice_nsc, f"found file {f}") return v - warn("number of supercells could not be read from output files. Assuming molecule cell " - "(no supercell connections)") + warn( + "number of supercells could not be read from output files. Assuming molecule cell " + "(no supercell connections)" + ) return _a.onesi(3) def _r_lattice_nsc_nc(self, *args, **kwargs): @@ -723,7 +749,7 @@ def _r_lattice_nsc_orb_indx(self, *args, **kwargs): return None def read_lattice(self, output=False, *args, **kwargs): - """ Returns Lattice object by reading fdf or Siesta output related files. + """Returns Lattice object by reading fdf or Siesta output related files. One can limit the tried files to only one file by passing only a single file ending. @@ -746,9 +772,9 @@ def read_lattice(self, output=False, *args, **kwargs): >>> fdf.read_lattice(order=["nc"]) # read from [nc] >>> fdf.read_lattice(True, order=["nc"]) # read from [nc] """ - order = _parse_output_order(kwargs.pop("order", None), output, - "nc XV TSHS fdf".split(), - "fdf") + order = _parse_output_order( + kwargs.pop("order", None), output, "nc XV TSHS fdf".split(), "fdf" + ) for f in order: v = getattr(self, f"_r_lattice_{f.lower()}")(*args, **kwargs) if v is not None: @@ -757,7 +783,7 @@ def read_lattice(self, output=False, *args, **kwargs): return None def _r_lattice_fdf(self, *args, **kwargs): - """ Returns `Lattice` object from the FDF file """ + """Returns `Lattice` object from the FDF file""" s = self.get("LatticeConstant", unit="Ang") if s is None: raise SileError("Could not find LatticeConstant in file") @@ -776,7 +802,9 @@ def _r_lattice_fdf(self, *args, **kwargs): cell = Lattice.tocell(*tmp) if lc is None: # the fdf file contains neither the latticevectors or parameters - raise SileError("Could not find LatticeVectors or LatticeParameters block in file") + raise SileError( + "Could not find LatticeVectors or LatticeParameters block in file" + ) cell *= s # When reading from the fdf, the warning should be suppressed @@ -795,7 +823,7 @@ def _r_lattice_nc(self): return None def _r_lattice_xv(self, *args, **kwargs): - """ Returns `Lattice` object from the XV file """ + """Returns `Lattice` object from the XV file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".XV") _track_file(self._r_lattice_xv, f) if f.is_file(): @@ -806,7 +834,7 @@ def _r_lattice_xv(self, *args, **kwargs): return None def _r_lattice_struct(self, *args, **kwargs): - """ Returns `Lattice` object from the STRUCT files """ + """Returns `Lattice` object from the STRUCT files""" for end in ["STRUCT_NEXT_ITER", "STRUCT_OUT", "STRUCT_IN"]: f = self.dir_file(self.get("SystemLabel", default="siesta") + f".{end}") _track_file(self._r_lattice_struct, f) @@ -832,7 +860,7 @@ def _r_lattice_onlys(self, *args, **kwargs): return None def read_force(self, *args, **kwargs): - """ Read forces from the output of the calculation (forces are not defined in the input) + """Read forces from the output of the calculation (forces are not defined in the input) Parameters ---------- @@ -853,7 +881,7 @@ def read_force(self, *args, **kwargs): return None def _r_force_fa(self, *args, **kwargs): - """ Read forces from the FA file """ + """Read forces from the FA file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".FA") _track_file(self._r_force_fa, f) if f.is_file(): @@ -861,7 +889,7 @@ def _r_force_fa(self, *args, **kwargs): return None def _r_force_fac(self, *args, **kwargs): - """ Read forces from the FAC file """ + """Read forces from the FAC file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".FAC") _track_file(self._r_force_fac, f) if f.is_file(): @@ -869,7 +897,7 @@ def _r_force_fac(self, *args, **kwargs): return None def _r_force_tsfa(self, *args, **kwargs): - """ Read forces from the TSFA file """ + """Read forces from the TSFA file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".TSFA") _track_file(self._r_force_tsfa, f) if f.is_file(): @@ -877,7 +905,7 @@ def _r_force_tsfa(self, *args, **kwargs): return None def _r_force_tsfac(self, *args, **kwargs): - """ Read forces from the TSFAC file """ + """Read forces from the TSFAC file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".TSFAC") _track_file(self._r_force_tsfac, f) if f.is_file(): @@ -885,7 +913,7 @@ def _r_force_tsfac(self, *args, **kwargs): return None def _r_force_nc(self, *args, **kwargs): - """ Read forces from the nc file """ + """Read forces from the nc file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".nc") _track_file(self._r_force_nc, f) if f.is_file(): @@ -893,7 +921,7 @@ def _r_force_nc(self, *args, **kwargs): return None def read_force_constant(self, *args, **kwargs): - """ Read force constant from the output of the calculation + """Read force constant from the output of the calculation Returns ------- @@ -928,7 +956,7 @@ def _r_force_constant_fc(self, *args, **kwargs): return None def read_fermi_level(self, *args, **kwargs): - """ Read fermi-level from output of the calculation + """Read fermi-level from output of the calculation Parameters ---------- @@ -941,7 +969,9 @@ def read_fermi_level(self, *args, **kwargs): Ef : float fermi-level """ - order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "TSDE", "TSHS", "EIG", "bands"], []) + order = _parse_output_order( + kwargs.pop("order", None), True, ["nc", "TSDE", "TSHS", "EIG", "bands"], [] + ) for f in order: v = getattr(self, f"_r_fermi_level_{f.lower()}")(*args, **kwargs) if v is not None: @@ -986,7 +1016,7 @@ def _r_fermi_level_bands(self): return None def read_dynamical_matrix(self, *args, **kwargs): - """ Read dynamical matrix from output of the calculation + """Read dynamical matrix from output of the calculation Generally the mass is stored in the basis information output, but for dynamical matrices it makes sense to let the user control this, @@ -1049,7 +1079,10 @@ def _r_dynamical_matrix_fc(self, *args, **kwargs): geom.atoms.replace(i, atom) # Get list of FC atoms - FC_atoms = _a.arangei(self.get("MD.FCFirst", default=1) - 1, self.get("MD.FCLast", default=geom.na)) + FC_atoms = _a.arangei( + self.get("MD.FCFirst", default=1) - 1, + self.get("MD.FCLast", default=geom.na), + ) return self._dynamical_matrix_from_fc(geom, FC, FC_atoms, *args, **kwargs) def _r_dynamical_matrix_nc(self, *args, **kwargs): @@ -1066,7 +1099,10 @@ def _r_dynamical_matrix_nc(self, *args, **kwargs): # Get list of FC atoms # TODO change to read in from the NetCDF file - FC_atoms = _a.arangei(self.get("MD.FCFirst", default=1) - 1, self.get("MD.FCLast", default=geom.na)) + FC_atoms = _a.arangei( + self.get("MD.FCFirst", default=1) - 1, + self.get("MD.FCLast", default=geom.na), + ) return self._dynamical_matrix_from_fc(geom, FC, FC_atoms, *args, **kwargs) def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): @@ -1080,21 +1116,21 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): hermitian = kwargs.get("hermitian", True) # Figure out the "original" periodic directions - #periodic = geom.nsc > 1 + # periodic = geom.nsc > 1 # Create conversion from eV/Ang^2 to correct units # Further down we are multiplying with [1 / amu] scale = constant.hbar / units("Ang", "m") / units("eV amu", "J kg") ** 0.5 # Cut-off too small values - fc_cut = kwargs.get("cutoff", 0.) - FC = np.where(np.fabs(FC) > fc_cut, FC, 0.) + fc_cut = kwargs.get("cutoff", 0.0) + FC = np.where(np.fabs(FC) > fc_cut, FC, 0.0) # Convert the force constant such that a diagonalization returns eV ^ 2 # FC is in [eV / Ang^2] # Convert the geometry to contain 3 orbitals per atom (x, y, z) - R = kwargs.get("cutoff_dist", -2.) + R = kwargs.get("cutoff_dist", -2.0) orbs = [Orbital(R / 2, tag=tag) for tag in "xyz"] with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -1110,8 +1146,10 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): supercell = [1] * 3 elif supercell is True: _, supercell = geom.as_primary(FC.shape[0], ret_super=True) - info("{}.read_dynamical_matrix(FC) guessed on a [{}, {}, {}] " - "supercell calculation.".format(str(self), *supercell)) + info( + "{}.read_dynamical_matrix(FC) guessed on a [{}, {}, {}] " + "supercell calculation.".format(str(self), *supercell) + ) # Convert to integer array supercell = _a.asarrayi(supercell) @@ -1156,13 +1194,14 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): geom_small = Geometry(geom.xyz[FC_atoms], geom.atoms[FC_atoms], lattice) # Convert the big geometry's coordinates to fractional coordinates of the small unit-cell. - isc_xyz = (geom.xyz.dot(geom_small.lattice.icell.T) - - np.tile(geom_small.fxyz, (np.product(supercell), 1))) + isc_xyz = geom.xyz.dot(geom_small.lattice.icell.T) - np.tile( + geom_small.fxyz, (np.product(supercell), 1) + ) axis_tiling = [] offset = len(geom_small) for _ in (supercell > 1).nonzero()[0]: - first_isc = (np.around(isc_xyz[FC_atoms + offset, :]) == 1.).sum(0) + first_isc = (np.around(isc_xyz[FC_atoms + offset, :]) == 1.0).sum(0) axis_tiling.append(np.argmax(first_isc)) # Fix the offset and wrap-around offset = (offset * supercell[axis_tiling[-1]]) % na_full @@ -1179,9 +1218,11 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): # Proximity check of 0.01 Ang (TODO add this as an argument) for ax in range(3): daxis = geom_tile.xyz[:, ax] - geom.xyz[:, ax] - if not np.allclose(daxis, daxis[0], rtol=0., atol=0.01): - raise SislError(f"{self!s}.read_dynamical_matrix(FC) could " - "not figure out the tiling method for the supercell") + if not np.allclose(daxis, daxis[0], rtol=0.0, atol=0.01): + raise SislError( + f"{self!s}.read_dynamical_matrix(FC) could " + "not figure out the tiling method for the supercell" + ) # Convert the FC matrix to a "rollable" matrix # This will make it easier to symmetrize @@ -1213,25 +1254,33 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): # Now swap the [2, 3, 4] dimensions so that we get in order of lattice vectors # x, y, z - FC = np.transpose(FC, (0, 1, *(axis_tiling.index(i)+2 for i in range(3)), 5, 6)) + FC = np.transpose( + FC, (0, 1, *(axis_tiling.index(i) + 2 for i in range(3)), 5, 6) + ) del axis_tiling # Now FC is sorted according to the supercell tiling # TODO this will probably fail if: FC_atoms.size != FC.shape[5] from ._help import _fc_correct - FC = _fc_correct(FC, trans_inv=kwargs.get("trans_inv", True), - sum0=kwargs.get("sum0", True), - hermitian=hermitian) + + FC = _fc_correct( + FC, + trans_inv=kwargs.get("trans_inv", True), + sum0=kwargs.get("sum0", True), + hermitian=hermitian, + ) # Remove ghost-atoms or atoms with 0 mass! # TODO check if ghost-atoms should be taken into account in _fc_correct - idx = (geom.atoms.mass == 0.).nonzero()[0] + idx = (geom.atoms.mass == 0.0).nonzero()[0] if len(idx) > 0: - #FC = np.delete(FC, idx, axis=5) - #geom = geom.remove(idx) - #geom.set_nsc([1] * 3) - raise NotImplementedError(f"{self}.read_dynamical_matrix could not reduce geometry " - "since there are atoms with 0 mass.") + # FC = np.delete(FC, idx, axis=5) + # geom = geom.remove(idx) + # geom.set_nsc([1] * 3) + raise NotImplementedError( + f"{self}.read_dynamical_matrix could not reduce geometry " + "since there are atoms with 0 mass." + ) # Now we can build the dynamical matrix (it will always be real) @@ -1241,13 +1290,12 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): FC = np.squeeze(FC, axis=(2, 3, 4)) # Instead of doing the sqrt in all D = FC (below) we do it here - m = scale / geom.atoms.mass ** 0.5 + m = scale / geom.atoms.mass**0.5 FC *= m[FC_atoms].reshape(-1, 1, 1, 1) * m.reshape(1, 1, -1, 1) j_FC_atoms = FC_atoms idx = _a.arangei(len(FC_atoms)) for ia, fia in enumerate(FC_atoms): - if R > 0: # find distances between the other atoms to cut-off the distance idx = geom.close(fia, R=R, atoms=FC_atoms) @@ -1255,13 +1303,15 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): j_FC_atoms = FC_atoms[idx] for ja, fja in zip(idx, j_FC_atoms): - D[ia*3:(ia+1)*3, ja*3:(ja+1)*3] = FC[ia, :, fja, :] + D[ia * 3 : (ia + 1) * 3, ja * 3 : (ja + 1) * 3] = FC[ia, :, fja, :] else: geom = geom_small if np.any(np.diff(FC_atoms) != 1): - raise SislError(f"{self}.read_dynamical_matrix(FC) requires the FC atoms to be consecutive!") + raise SislError( + f"{self}.read_dynamical_matrix(FC) requires the FC atoms to be consecutive!" + ) # Re-order FC matrix so the FC-atoms are first if FC.shape[0] != FC.shape[5]: @@ -1273,10 +1323,12 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): if FC_atoms[0] != 0: # TODO we could roll the axis such that the displaced atoms moves into the # first elements - raise SislError(f"{self}.read_dynamical_matrix(FC) requires the displaced atoms to start from 1!") + raise SislError( + f"{self}.read_dynamical_matrix(FC) requires the displaced atoms to start from 1!" + ) # After having done this we can easily mass scale all FC components - m = scale / geom.atoms.mass ** 0.5 + m = scale / geom.atoms.mass**0.5 FC *= m.reshape(-1, 1, 1, 1, 1, 1, 1) * m.reshape(1, 1, 1, 1, 1, -1, 1) # Check whether we need to "halve" the equivalent supercell @@ -1325,10 +1377,14 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): for ia in iter_FC_atoms: # Reduce second loop based on radius cutoff if R > 0: - iter_j_FC_atoms = iter_FC_atoms[dist(ia, aoff + iter_FC_atoms) <= R] + iter_j_FC_atoms = iter_FC_atoms[ + dist(ia, aoff + iter_FC_atoms) <= R + ] for ja in iter_j_FC_atoms: - D[ia*3:(ia+1)*3, joff+ja*3:joff+(ja+1)*3] += FC[ia, :, x, y, z, ja, :] + D[ + ia * 3 : (ia + 1) * 3, joff + ja * 3 : joff + (ja + 1) * 3 + ] += FC[ia, :, x, y, z, ja, :] D = D.tocsr() # Remove all zeros @@ -1341,7 +1397,7 @@ def _dynamical_matrix_from_fc(self, geom, FC, FC_atoms, *args, **kwargs): return D def read_geometry(self, output=False, *args, **kwargs): - """ Returns Geometry object by reading fdf or Siesta output related files. + """Returns Geometry object by reading fdf or Siesta output related files. One can limit the tried files to only one file by passing only a single file ending. @@ -1369,9 +1425,12 @@ def read_geometry(self, output=False, *args, **kwargs): # When adding more capabilities please check the read_geometry(order=...) in this # code to correct. ## - order = _parse_output_order(kwargs.pop("order", None), output, - "XV nc TSHS STRUCT fdf HSX".split(), - "fdf") + order = _parse_output_order( + kwargs.pop("order", None), + output, + "XV nc TSHS STRUCT fdf HSX".split(), + "fdf", + ) for f in order: v = getattr(self, f"_r_geometry_{f.lower()}")(*args, **kwargs) if v is not None: @@ -1381,7 +1440,7 @@ def read_geometry(self, output=False, *args, **kwargs): return None def _r_geometry_xv(self, *args, **kwargs): - """ Returns `Geometry` object from the XV file """ + """Returns `Geometry` object from the XV file""" geom = None f = self.dir_file(self.get("SystemLabel", default="siesta") + ".XV") _track_file(self._r_geometry_xv, f) @@ -1394,14 +1453,14 @@ def _r_geometry_xv(self, *args, **kwargs): with warnings.catch_warnings(): warnings.simplefilter("ignore") for atom, _ in geom.atoms.iter(True): - geom.atoms.replace(atom, basis[atom.Z-1]) + geom.atoms.replace(atom, basis[atom.Z - 1]) geom.reduce() nsc = self.read_lattice_nsc() geom.set_nsc(nsc) return geom def _r_geometry_struct(self, *args, **kwargs): - """ Returns `Geometry` object from the STRUCT_* files """ + """Returns `Geometry` object from the STRUCT_* files""" geom = None for end in ["STRUCT_NEXT_ITER", "STRUCT_OUT", "STRUCT_IN"]: f = self.dir_file(self.get("SystemLabel", default="siesta") + f".{end}") @@ -1415,7 +1474,7 @@ def _r_geometry_struct(self, *args, **kwargs): with warnings.catch_warnings(): warnings.simplefilter("ignore") for atom, _ in geom.atoms.iter(True): - geom.atoms.replace(atom, basis[atom.Z-1]) + geom.atoms.replace(atom, basis[atom.Z - 1]) geom.reduce() nsc = self.read_lattice_nsc() geom.set_nsc(nsc) @@ -1449,7 +1508,7 @@ def _r_geometry_hsx(self): return None def _r_geometry_fdf(self, *args, **kwargs): - """ Returns Geometry object from the FDF file + """Returns Geometry object from the FDF file NOTE: Interaction range of the Atoms are currently not read. """ @@ -1461,7 +1520,7 @@ def _r_geometry_fdf(self, *args, **kwargs): # Read atom scaling lc = self.get("AtomicCoordinatesFormat", default="Bohr").lower() if "ang" in lc or "notscaledcartesianang" in lc: - s = 1. + s = 1.0 elif "bohr" in lc or "notscaledcartesianbohr" in lc: s = Bohr2Ang elif "scaledcartesian" in lc: @@ -1470,7 +1529,7 @@ def _r_geometry_fdf(self, *args, **kwargs): elif "fractional" in lc or "scaledbylatticevectors" in lc: # no scaling of coordinates as that is entirely # done by the latticevectors - s = 1. + s = 1.0 is_frac = True # If the user requests a shifted geometry @@ -1489,7 +1548,9 @@ def _r_geometry_fdf(self, *args, **kwargs): # Read atom block atms = self.get("AtomicCoordinatesAndAtomicSpecies") if atms is None: - raise SileError("AtomicCoordinatesAndAtomicSpecies block could not be found") + raise SileError( + "AtomicCoordinatesAndAtomicSpecies block could not be found" + ) # Read number of atoms and block # We default to the number of elements in the @@ -1501,7 +1562,9 @@ def _r_geometry_fdf(self, *args, **kwargs): # align number of atoms and atms array atms = atms[:na] elif na > len(atms): - raise SileError("NumberOfAtoms is larger than the atoms defined in the blocks") + raise SileError( + "NumberOfAtoms is larger than the atoms defined in the blocks" + ) elif na == 0: raise SileError("NumberOfAtoms has been determined to be zero, no atoms.") @@ -1519,7 +1582,9 @@ def _r_geometry_fdf(self, *args, **kwargs): # Read the block (not strictly needed, if so we simply set all atoms to H) atoms = self.read_basis() if atoms is None: - warn("Block ChemicalSpeciesLabel does not exist, cannot determine the basis (all Hydrogen).") + warn( + "Block ChemicalSpeciesLabel does not exist, cannot determine the basis (all Hydrogen)." + ) # Default atom (hydrogen) atoms = Atom(1) @@ -1537,21 +1602,21 @@ def _r_geometry_fdf(self, *args, **kwargs): w /= w.sum() origin = lattice.cell.sum(0) * 0.5 - np.average(xyz, 0, weights=w) elif opt.startswith("min"): - origin = - np.amin(xyz, 0) + origin = -np.amin(xyz, 0) if len(opt) > 4: opt = opt[4:] if opt == "x": - origin[1:] = 0. + origin[1:] = 0.0 elif opt == "y": - origin[[0, 2]] = 0. + origin[[0, 2]] = 0.0 elif opt == "z": - origin[:2] = 0. + origin[:2] = 0.0 elif opt in ("xy", "yx"): - origin[2] = 0. + origin[2] = 0.0 elif opt in ("xz", "zx"): - origin[1] = 0. + origin[1] = 0.0 elif opt in ("yz", "zy"): - origin[0] = 0. + origin[0] = 0.0 # create geometry xyz += origin @@ -1562,15 +1627,16 @@ def _r_geometry_fdf(self, *args, **kwargs): if supercell is not None: # we need to expand # check that we are only dealing with an orthogonal supercell - supercell = np.array([[int(x) for x in line.split()] - for line in supercell]) + supercell = np.array([[int(x) for x in line.split()] for line in supercell]) assert supercell.shape == (3, 3) # Check it is diagonal diag = np.diag(supercell) if not np.allclose(supercell - np.diag(diag), 0): - raise SileError("Lattice input is not diagonal, currently not implemented in sisl") + raise SileError( + "Lattice input is not diagonal, currently not implemented in sisl" + ) # now tile it for axis, nt in enumerate(diag): @@ -1579,7 +1645,7 @@ def _r_geometry_fdf(self, *args, **kwargs): return geom def read_grid(self, name, *args, **kwargs): - """ Read grid related information from any of the output files + """Read grid related information from any of the output files The order of the readed data is shown below. @@ -1596,9 +1662,13 @@ def read_grid(self, name, *args, **kwargs): the order of which to try and read the geometry. By default this is ``["nc", "grid.nc", "bin"]`` (bin refers to the binary files) """ - order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "grid.nc", "bin"], []) + order = _parse_output_order( + kwargs.pop("order", None), True, ["nc", "grid.nc", "bin"], [] + ) for f in order: - v = getattr(self, f"_r_grid_{f.lower().replace('.', '_')}")(name, *args, **kwargs) + v = getattr(self, f"_r_grid_{f.lower().replace('.', '_')}")( + name, *args, **kwargs + ) if v is not None: if self.track: info(f"{self.file}(read_grid) found in file={f}") @@ -1611,50 +1681,52 @@ def _r_grid_nc(self, name, *args, **kwargs): _track_file(self._r_grid_nc, f) if f.is_file(): # Capitalize correctly - name = {"rho": "Rho", - "rhoinit": "RhoInit", - "vna": "Vna", - "ioch": "Chlocal", - "chlocal": "Chlocal", - "toch": "RhoTot", - "totalcharge": "RhoTot", - "rhotot": "RhoTot", - "drho": "RhoDelta", - "deltarho": "RhoDelta", - "rhodelta": "RhoDelta", - "vh": "Vh", - "electrostaticpotential": "Vh", - "rhoxc": "RhoXC", - "vt": "Vt", - "totalpotential": "Vt", - "bader": "RhoBader", - "baderrho": "RhoBader", - "rhobader": "RhoBader" + name = { + "rho": "Rho", + "rhoinit": "RhoInit", + "vna": "Vna", + "ioch": "Chlocal", + "chlocal": "Chlocal", + "toch": "RhoTot", + "totalcharge": "RhoTot", + "rhotot": "RhoTot", + "drho": "RhoDelta", + "deltarho": "RhoDelta", + "rhodelta": "RhoDelta", + "vh": "Vh", + "electrostaticpotential": "Vh", + "rhoxc": "RhoXC", + "vt": "Vt", + "totalpotential": "Vt", + "bader": "RhoBader", + "baderrho": "RhoBader", + "rhobader": "RhoBader", }.get(name.lower()) return ncSileSiesta(f).read_grid(name, **kwargs) return None def _r_grid_grid_nc(self, name, *args, **kwargs): # Read grid from the <>.nc file - name = {"rho": "Rho", - "rhoinit": "RhoInit", - "vna": "Vna", - "ioch": "Chlocal", - "chlocal": "Chlocal", - "toch": "TotalCharge", - "totalcharge": "TotalCharge", - "rhotot": "TotalCharge", - "drho": "DeltaRho", - "deltarho": "DeltaRho", - "rhodelta": "DeltaRho", - "vh": "ElectrostaticPotential", - "electrostaticpotential": "ElectrostaticPotential", - "rhoxc": "RhoXC", - "vt": "TotalPotential", - "totalpotential": "TotalPotential", - "bader": "BaderCharge", - "baderrho": "BaderCharge", - "rhobader": "BaderCharge" + name = { + "rho": "Rho", + "rhoinit": "RhoInit", + "vna": "Vna", + "ioch": "Chlocal", + "chlocal": "Chlocal", + "toch": "TotalCharge", + "totalcharge": "TotalCharge", + "rhotot": "TotalCharge", + "drho": "DeltaRho", + "deltarho": "DeltaRho", + "rhodelta": "DeltaRho", + "vh": "ElectrostaticPotential", + "electrostaticpotential": "ElectrostaticPotential", + "rhoxc": "RhoXC", + "vt": "TotalPotential", + "totalpotential": "TotalPotential", + "bader": "BaderCharge", + "baderrho": "BaderCharge", + "rhobader": "BaderCharge", }.get(name.lower()) + ".grid.nc" f = self.dir_file(name) @@ -1667,25 +1739,26 @@ def _r_grid_grid_nc(self, name, *args, **kwargs): def _r_grid_bin(self, name, *args, **kwargs): # Read grid from the <>.VT/... file - name = {"rho": ".RHO", - "rhoinit": ".RHOINIT", - "vna": ".VNA", - "ioch": ".IOCH", - "chlocal": ".IOCH", - "toch": ".TOCH", - "totalcharge": ".TOCH", - "rhotot": ".TOCH", - "drho": ".DRHO", - "deltarho": ".DRHO", - "rhodelta": ".DRHO", - "vh": ".VH", - "electrostaticpotential": ".VH", - "rhoxc": ".RHOXC", - "vt": ".VT", - "totalpotential": ".VT", - "bader": ".BADER", - "baderrho": ".BADER", - "rhobader": ".BADER" + name = { + "rho": ".RHO", + "rhoinit": ".RHOINIT", + "vna": ".VNA", + "ioch": ".IOCH", + "chlocal": ".IOCH", + "toch": ".TOCH", + "totalcharge": ".TOCH", + "rhotot": ".TOCH", + "drho": ".DRHO", + "deltarho": ".DRHO", + "rhodelta": ".DRHO", + "vh": ".VH", + "electrostaticpotential": ".VH", + "rhoxc": ".RHOXC", + "vt": ".VT", + "totalpotential": ".VT", + "bader": ".BADER", + "baderrho": ".BADER", + "rhobader": ".BADER", }.get(name.lower()) f = self.dir_file(self.get("SystemLabel", default="siesta") + name) @@ -1697,7 +1770,7 @@ def _r_grid_bin(self, name, *args, **kwargs): return None def read_basis(self, *args, **kwargs): - """ Read the atomic species and figure out the number of atomic orbitals in their basis + """Read the atomic species and figure out the number of atomic orbitals in their basis The order of the read is shown below. @@ -1710,7 +1783,9 @@ def read_basis(self, *args, **kwargs): the order of which to try and read the basis information. By default this is ``["nc", "ion", "ORB_INDX", "fdf"]`` """ - order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "ion", "ORB_INDX", "fdf"], []) + order = _parse_output_order( + kwargs.pop("order", None), True, ["nc", "ion", "ORB_INDX", "fdf"], [] + ) for f in order: v = getattr(self, f"_r_basis_{f.lower()}")(*args, **kwargs) if v is not None: @@ -1741,7 +1816,7 @@ def _r_basis_ion(self): found_all = True for spc in spcs: idx, Z, lbl = spc.split()[:3] - idx = int(idx) - 1 # F-indexing + idx = int(idx) - 1 # F-indexing Z = int(Z) lbl = lbl.strip() f = self.dir_file(lbl + ".ext") @@ -1759,8 +1834,10 @@ def _r_basis_ion(self): found_all = False if found_one and not found_all: - warn("Siesta basis information could not read all ion.nc/ion.xml files. " - "Only a subset of the basis information is accessible.") + warn( + "Siesta basis information could not read all ion.nc/ion.xml files. " + "Only a subset of the basis information is accessible." + ) elif not found_one: return None return atoms @@ -1769,7 +1846,9 @@ def _r_basis_orb_indx(self): f = self.dir_file(self.get("SystemLabel", default="siesta") + ".ORB_INDX") _track_file(self._r_basis_orb_indx, f) if f.is_file(): - info(f"Siesta basis information is read from {f}, the radial functions are not accessible.") + info( + f"Siesta basis information is read from {f}, the radial functions are not accessible." + ) return orbindxSileSiesta(f).read_basis(atoms=self._r_basis_fdf()) return None @@ -1794,7 +1873,7 @@ def _r_basis_fdf(self): # Now spcs contains the block of the chemicalspecieslabel for spc in spcs: idx, Z, lbl = spc.split()[:3] - idx = int(idx) - 1 # F-indexing + idx = int(idx) - 1 # F-indexing Z = int(Z) lbl = lbl.strip() @@ -1824,7 +1903,7 @@ def _r_basis_fdf(self): @classmethod def _parse_pao_basis(cls, block, specie=None): - """ Parse the full PAO.Basis block with *optionally* only a single specie + """Parse the full PAO.Basis block with *optionally* only a single specie Notes ----- @@ -1898,7 +1977,7 @@ def parse_next(): nl_line = nl_line.replace("n=", "").split() # first 3|2: are n?, l, Nzeta - n = None # use default n defined in AtomitOrbital + n = None # use default n defined in AtomitOrbital first = int(nl_line.pop(0)) second = int(nl_line.pop(0)) try: @@ -1944,8 +2023,8 @@ def parse_next(): for izeta in range(orb.zeta, orb.zeta + nzeta): orb = SphericalOrbital(l, None, R=rc) orbs.extend(orb.toAtomicOrbital(n=n, zeta=izeta)) - for ipol in range(1, npol+1): - orb = SphericalOrbital(l+1, None, R=first_zeta.R) + for ipol in range(1, npol + 1): + orb = SphericalOrbital(l + 1, None, R=first_zeta.R) orbs.extend(orb.toAtomicOrbital(n=n, zeta=ipol, P=True)) return tag, orbs @@ -1961,7 +2040,7 @@ def parse_next(): return atoms.get(specie, None) def _r_add_overlap(self, parent_call, M): - """ Internal routine to ensure that the overlap matrix is read and added to the matrix `M` """ + """Internal routine to ensure that the overlap matrix is read and added to the matrix `M`""" try: S = self.read_overlap() # Check for the same sparsity pattern @@ -1970,10 +2049,12 @@ def _r_add_overlap(self, parent_call, M): else: raise ValueError except Exception: - warn(f"{self!s} could not succesfully read the overlap matrix in {parent_call}.") + warn( + f"{self!s} could not succesfully read the overlap matrix in {parent_call}." + ) def read_density_matrix(self, *args, **kwargs): - """ Try and read density matrix by reading the <>.nc, <>.TSDE files, <>.DM (in that order) + """Try and read density matrix by reading the <>.nc, <>.TSDE files, <>.DM (in that order) One can limit the tried files to only one file by passing only a single file ending. @@ -1984,7 +2065,9 @@ def read_density_matrix(self, *args, **kwargs): the order of which to try and read the density matrix By default this is ``["nc", "TSDE", "DM"]``. """ - order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "TSDE", "DM"], []) + order = _parse_output_order( + kwargs.pop("order", None), True, ["nc", "TSDE", "DM"], [] + ) for f in order: DM = getattr(self, f"_r_density_matrix_{f.lower()}")(*args, **kwargs) if DM is not None: @@ -1993,7 +2076,7 @@ def read_density_matrix(self, *args, **kwargs): return None def _r_density_matrix_nc(self, *args, **kwargs): - """ Try and read the density matrix by reading the <>.nc """ + """Try and read the density matrix by reading the <>.nc""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".nc") _track_file(self._r_density_matrix_nc, f) DM = None @@ -2003,7 +2086,7 @@ def _r_density_matrix_nc(self, *args, **kwargs): return DM def _r_density_matrix_tsde(self, *args, **kwargs): - """ Read density matrix from the TSDE file """ + """Read density matrix from the TSDE file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".TSDE") _track_file(self._r_density_matrix_tsde, f) DM = None @@ -2016,7 +2099,7 @@ def _r_density_matrix_tsde(self, *args, **kwargs): return DM def _r_density_matrix_dm(self, *args, **kwargs): - """ Read density matrix from the DM file """ + """Read density matrix from the DM file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".DM") _track_file(self._r_density_matrix_dm, f) DM = None @@ -2029,7 +2112,7 @@ def _r_density_matrix_dm(self, *args, **kwargs): return DM def read_energy_density_matrix(self, *args, **kwargs): - """ Try and read energy density matrix by reading the <>.nc or <>.TSDE files (in that order) + """Try and read energy density matrix by reading the <>.nc or <>.TSDE files (in that order) One can limit the tried files to only one file by passing only a single file ending. @@ -2042,14 +2125,16 @@ def read_energy_density_matrix(self, *args, **kwargs): """ order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "TSDE"], []) for f in order: - EDM = getattr(self, f"_r_energy_density_matrix_{f.lower()}")(*args, **kwargs) + EDM = getattr(self, f"_r_energy_density_matrix_{f.lower()}")( + *args, **kwargs + ) if EDM is not None: _track(self.read_energy_density_matrix, f"found file {f}") return EDM return None def _r_energy_density_matrix_nc(self, *args, **kwargs): - """ Read energy density matrix by reading the <>.nc """ + """Read energy density matrix by reading the <>.nc""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".nc") _track_file(self._r_energy_density_matrix_nc, f) if f.is_file(): @@ -2057,7 +2142,7 @@ def _r_energy_density_matrix_nc(self, *args, **kwargs): return None def _r_energy_density_matrix_tsde(self, *args, **kwargs): - """ Read energy density matrix from the TSDE file """ + """Read energy density matrix from the TSDE file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".TSDE") _track_file(self._r_energy_density_matrix_tsde, f) EDM = None @@ -2070,7 +2155,7 @@ def _r_energy_density_matrix_tsde(self, *args, **kwargs): return EDM def read_overlap(self, *args, **kwargs): - """ Try and read the overlap matrix by reading the <>.nc, <>.TSHS files, <>.HSX, <>.onlyS (in that order) + """Try and read the overlap matrix by reading the <>.nc, <>.TSHS files, <>.HSX, <>.onlyS (in that order) One can limit the tried files to only one file by passing only a single file ending. @@ -2081,7 +2166,9 @@ def read_overlap(self, *args, **kwargs): the order of which to try and read the overlap matrix By default this is ``["nc", "TSHS", "HSX", "onlyS"]``. """ - order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "TSHS", "HSX", "onlyS"], []) + order = _parse_output_order( + kwargs.pop("order", None), True, ["nc", "TSHS", "HSX", "onlyS"], [] + ) for f in order: v = getattr(self, f"_r_overlap_{f.lower()}")(*args, **kwargs) if v is not None: @@ -2090,7 +2177,7 @@ def read_overlap(self, *args, **kwargs): return None def _r_overlap_nc(self, *args, **kwargs): - """ Read overlap from the nc file """ + """Read overlap from the nc file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".nc") _track_file(self._r_overlap_nc, f) if f.is_file(): @@ -2098,7 +2185,7 @@ def _r_overlap_nc(self, *args, **kwargs): return None def _r_overlap_tshs(self, *args, **kwargs): - """ Read overlap from the TSHS file """ + """Read overlap from the TSHS file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".TSHS") _track_file(self._r_overlap_tshs, f) S = None @@ -2110,7 +2197,7 @@ def _r_overlap_tshs(self, *args, **kwargs): return S def _r_overlap_hsx(self, *args, **kwargs): - """ Read overlap from the HSX file """ + """Read overlap from the HSX file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".HSX") _track_file(self._r_overlap_hsx, f) S = None @@ -2122,7 +2209,7 @@ def _r_overlap_hsx(self, *args, **kwargs): return S def _r_overlap_onlys(self, *args, **kwargs): - """ Read overlap from the onlyS file """ + """Read overlap from the onlyS file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".onlyS") _track_file(self._r_overlap_onlys, f) S = None @@ -2134,7 +2221,7 @@ def _r_overlap_onlys(self, *args, **kwargs): return S def read_hamiltonian(self, *args, **kwargs): - """ Try and read the Hamiltonian by reading the <>.nc, <>.TSHS files, <>.HSX (in that order) + """Try and read the Hamiltonian by reading the <>.nc, <>.TSHS files, <>.HSX (in that order) One can limit the tried files to only one file by passing only a single file ending. @@ -2145,7 +2232,9 @@ def read_hamiltonian(self, *args, **kwargs): the order of which to try and read the Hamiltonian. By default this is ``["nc", "TSHS", "HSX"]``. """ - order = _parse_output_order(kwargs.pop("order", None), True, ["nc", "TSHS", "HSX"], []) + order = _parse_output_order( + kwargs.pop("order", None), True, ["nc", "TSHS", "HSX"], [] + ) for f in order: H = getattr(self, f"_r_hamiltonian_{f.lower()}")(*args, **kwargs) if H is not None: @@ -2154,7 +2243,7 @@ def read_hamiltonian(self, *args, **kwargs): return None def _r_hamiltonian_nc(self, *args, **kwargs): - """ Read Hamiltonian from the nc file """ + """Read Hamiltonian from the nc file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".nc") _track_file(self._r_hamiltonian_nc, f) if f.is_file(): @@ -2162,7 +2251,7 @@ def _r_hamiltonian_nc(self, *args, **kwargs): return None def _r_hamiltonian_tshs(self, *args, **kwargs): - """ Read Hamiltonian from the TSHS file """ + """Read Hamiltonian from the TSHS file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".TSHS") _track_file(self._r_hamiltonian_tshs, f) H = None @@ -2174,7 +2263,7 @@ def _r_hamiltonian_tshs(self, *args, **kwargs): return H def _r_hamiltonian_hsx(self, *args, **kwargs): - """ Read Hamiltonian from the HSX file """ + """Read Hamiltonian from the HSX file""" f = self.dir_file(self.get("SystemLabel", default="siesta") + ".HSX") _track_file(self._r_hamiltonian_hsx, f) H = None @@ -2186,14 +2275,16 @@ def _r_hamiltonian_hsx(self, *args, **kwargs): H = hsxSileSiesta(f).read_hamiltonian(*args, **kwargs) Ef = self.read_fermi_level() if Ef is None: - info(f"{self!s}.read_hamiltonian from HSX file failed shifting to the Fermi-level.") + info( + f"{self!s}.read_hamiltonian from HSX file failed shifting to the Fermi-level." + ) else: H.shift(-Ef) return H @default_ArgumentParser(description="Manipulate a FDF file.") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" import argparse # We must by-pass this fdf-file for importing @@ -2203,7 +2294,9 @@ def ArgumentParser(self, p=None, *args, **kwargs): # The fdf parser is more complicated # It is based on different settings based on the - sp = p.add_subparsers(help="Determine which part of the fdf-file that should be processed.") + sp = p.add_subparsers( + help="Determine which part of the fdf-file that should be processed." + ) # Get the label which retains all the sub-modules label = self.get("SystemLabel", default="siesta") @@ -2215,13 +2308,13 @@ def label_file(suffix): # The default on all sub-parsers are the retrieval and setting namespace = default_namespace(_fdf=self, _fdf_first=True) - ep = sp.add_parser("edit", - help="Change or read and print data from the fdf file") + ep = sp.add_parser( + "edit", help="Change or read and print data from the fdf file" + ) # As the fdf may provide additional stuff, we do not add EVERYTHING from # the Geometry class. class FDFAdd(argparse.Action): - def __call__(self, parser, ns, values, option_string=None): key = values[0] val = values[1] @@ -2231,12 +2324,17 @@ def __call__(self, parser, ns, values, option_string=None): fd.write("\n\n# SISL added keywords\n") setattr(ns, "_fdf_first", False) ns._fdf.set(key, val) - ep.add_argument("--set", "-s", nargs=2, metavar=("KEY", "VALUE"), - action=FDFAdd, - help="Add a key to the FDF file. If it already exists it will be overwritten") - class FDFGet(argparse.Action): + ep.add_argument( + "--set", + "-s", + nargs=2, + metavar=("KEY", "VALUE"), + action=FDFAdd, + help="Add a key to the FDF file. If it already exists it will be overwritten", + ) + class FDFGet(argparse.Action): def __call__(self, parser, ns, value, option_string=None): # Retrieve the value in standard units # Currently, we write out the unit "as-is" @@ -2250,9 +2348,14 @@ def __call__(self, parser, ns, value, option_string=None): else: print(ns._fdf.print(value[0], val)) - ep.add_argument("--get", "-g", nargs=1, metavar="KEY", - action=FDFGet, - help="Print (to stdout) the value of the key in the FDF file.") + ep.add_argument( + "--get", + "-g", + nargs=1, + metavar="KEY", + action=FDFGet, + help="Print (to stdout) the value of the key in the FDF file.", + ) # If the XV file exists, it has precedence # of the contained geometry (we will issue @@ -2261,8 +2364,9 @@ def __call__(self, parser, ns, value, option_string=None): try: geom = self.read_geometry(True) - tmp_p = sp.add_parser("geom", - help="Edit the contained geometry in the file") + tmp_p = sp.add_parser( + "geom", help="Edit the contained geometry in the file" + ) tmp_p, tmp_ns = geom.ArgumentParser(tmp_p, *args, **kwargs) namespace = merge_instances(namespace, tmp_ns) except Exception: @@ -2271,27 +2375,32 @@ def __call__(self, parser, ns, value, option_string=None): f = label_file(".bands") if f.is_file(): - tmp_p = sp.add_parser("band", - help="Manipulate bands file from the Siesta simulation") - tmp_p, tmp_ns = sis.bandsSileSiesta(f).ArgumentParser(tmp_p, *args, **kwargs) + tmp_p = sp.add_parser( + "band", help="Manipulate bands file from the Siesta simulation" + ) + tmp_p, tmp_ns = sis.bandsSileSiesta(f).ArgumentParser( + tmp_p, *args, **kwargs + ) namespace = merge_instances(namespace, tmp_ns) f = label_file(".PDOS.xml") if f.is_file(): - tmp_p = sp.add_parser("pdos", - help="Manipulate PDOS.xml file from the Siesta simulation") + tmp_p = sp.add_parser( + "pdos", help="Manipulate PDOS.xml file from the Siesta simulation" + ) tmp_p, tmp_ns = sis.pdosSileSiesta(f).ArgumentParser(tmp_p, *args, **kwargs) namespace = merge_instances(namespace, tmp_ns) f = label_file(".EIG") if f.is_file(): - tmp_p = sp.add_parser("eig", - help="Manipulate EIG file from the Siesta simulation") + tmp_p = sp.add_parser( + "eig", help="Manipulate EIG file from the Siesta simulation" + ) tmp_p, tmp_ns = sis.eigSileSiesta(f).ArgumentParser(tmp_p, *args, **kwargs) namespace = merge_instances(namespace, tmp_ns) - #f = label + ".FA" - #if isfile(f): + # f = label + ".FA" + # if isfile(f): # tmp_p = sp.add_parser("force", # help="Manipulate FA file from the Siesta simulation") # tmp_p, tmp_ns = sis.faSileSiesta(f).ArgumentParser(tmp_p, *args, **kwargs) @@ -2299,36 +2408,43 @@ def __call__(self, parser, ns, value, option_string=None): f = label_file(".TBT.nc") if f.is_file(): - tmp_p = sp.add_parser("tbt", - help="Manipulate tbtrans output file") - tmp_p, tmp_ns = tbt.tbtncSileTBtrans(f).ArgumentParser(tmp_p, *args, **kwargs) + tmp_p = sp.add_parser("tbt", help="Manipulate tbtrans output file") + tmp_p, tmp_ns = tbt.tbtncSileTBtrans(f).ArgumentParser( + tmp_p, *args, **kwargs + ) namespace = merge_instances(namespace, tmp_ns) f = label_file(".TBT.Proj.nc") if f.is_file(): - tmp_p = sp.add_parser("tbt-proj", - help="Manipulate tbtrans projection output file") - tmp_p, tmp_ns = tbt.tbtprojncSileTBtrans(f).ArgumentParser(tmp_p, *args, **kwargs) + tmp_p = sp.add_parser( + "tbt-proj", help="Manipulate tbtrans projection output file" + ) + tmp_p, tmp_ns = tbt.tbtprojncSileTBtrans(f).ArgumentParser( + tmp_p, *args, **kwargs + ) namespace = merge_instances(namespace, tmp_ns) f = label_file(".PHT.nc") if f.is_file(): - tmp_p = sp.add_parser("pht", - help="Manipulate the phtrans output file") - tmp_p, tmp_ns = tbt.phtncSilePHtrans(f).ArgumentParser(tmp_p, *args, **kwargs) + tmp_p = sp.add_parser("pht", help="Manipulate the phtrans output file") + tmp_p, tmp_ns = tbt.phtncSilePHtrans(f).ArgumentParser( + tmp_p, *args, **kwargs + ) namespace = merge_instances(namespace, tmp_ns) f = label_file(".PHT.Proj.nc") if f.is_file(): - tmp_p = sp.add_parser("pht-proj", - help="Manipulate phtrans projection output file") - tmp_p, tmp_ns = tbt.phtprojncSilePHtrans(f).ArgumentParser(tmp_p, *args, **kwargs) + tmp_p = sp.add_parser( + "pht-proj", help="Manipulate phtrans projection output file" + ) + tmp_p, tmp_ns = tbt.phtprojncSilePHtrans(f).ArgumentParser( + tmp_p, *args, **kwargs + ) namespace = merge_instances(namespace, tmp_ns) f = label_file(".nc") if f.is_file(): - tmp_p = sp.add_parser("nc", - help="Manipulate Siesta NetCDF output file") + tmp_p = sp.add_parser("nc", help="Manipulate Siesta NetCDF output file") tmp_p, tmp_ns = sis.ncSileSiesta(f).ArgumentParser(tmp_p, *args, **kwargs) namespace = merge_instances(namespace, tmp_ns) diff --git a/src/sisl/io/siesta/kp.py b/src/sisl/io/siesta/kp.py index 5587174413..bcce731be3 100644 --- a/src/sisl/io/siesta/kp.py +++ b/src/sisl/io/siesta/kp.py @@ -10,19 +10,21 @@ from ..sile import add_sile, sile_fh_open, sile_raise_write from .sile import SileSiesta -__all__ = ['kpSileSiesta', 'rkpSileSiesta'] +__all__ = ["kpSileSiesta", "rkpSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') +Bohr2Ang = unit_convert("Bohr", "Ang") @set_module("sisl.io.siesta") class kpSileSiesta(SileSiesta): - """ k-points file in 1/Bohr units """ + """k-points file in 1/Bohr units""" @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def read_data(self, lattice=None): - """ Returns K-points from the file (note that these are in reciprocal units) + """Returns K-points from the file (note that these are in reciprocal units) Parameters ---------- @@ -51,8 +53,8 @@ def read_data(self, lattice=None): return np.dot(k, lattice.cell.T / (2 * np.pi)), w @sile_fh_open() - def write_data(self, k, weight, fmt='.9e'): - """ Writes K-points to file + def write_data(self, k, weight, fmt=".9e"): + """Writes K-points to file Parameters ---------- @@ -66,16 +68,18 @@ def write_data(self, k, weight, fmt='.9e'): sile_raise_write(self) nk = len(k) - self._write(f'{nk}\n') - _fmt = ('{:d}' + (' {:' + fmt + '}') * 4) + '\n' + self._write(f"{nk}\n") + _fmt = ("{:d}" + (" {:" + fmt + "}") * 4) + "\n" for i, (kk, w) in enumerate(zip(np.atleast_2d(k), weight)): self._write(_fmt.format(i + 1, kk[0], kk[1], kk[2], w)) @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def read_brillouinzone(self, lattice): - """ Returns K-points from the file (note that these are in reciprocal units) + """Returns K-points from the file (note that these are in reciprocal units) Parameters ---------- @@ -95,8 +99,8 @@ def read_brillouinzone(self, lattice): return bz @sile_fh_open() - def write_brillouinzone(self, bz, fmt='.9e'): - """ Writes BrillouinZone-points to file + def write_brillouinzone(self, bz, fmt=".9e"): + """Writes BrillouinZone-points to file Parameters ---------- @@ -112,7 +116,7 @@ def write_brillouinzone(self, bz, fmt='.9e'): @set_module("sisl.io.siesta") class rkpSileSiesta(kpSileSiesta): - """ Special k-point file with units in reciprocal lattice vectors + """Special k-point file with units in reciprocal lattice vectors Its main usage is as input for the kgrid.File fdf-option, in which case this file provides the k-points in the correct format. @@ -120,7 +124,7 @@ class rkpSileSiesta(kpSileSiesta): @sile_fh_open() def read_data(self): - """ Returns K-points from the file (note that these are in reciprocal units) + """Returns K-points from the file (note that these are in reciprocal units) Returns ------- @@ -139,9 +143,11 @@ def read_data(self): return k, w @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def read_brillouinzone(self, lattice): - """ Returns K-points from the file + """Returns K-points from the file Parameters ---------- @@ -161,8 +167,8 @@ def read_brillouinzone(self, lattice): return bz @sile_fh_open() - def write_brillouinzone(self, bz, fmt='.9e'): - """ Writes BrillouinZone-points to file + def write_brillouinzone(self, bz, fmt=".9e"): + """Writes BrillouinZone-points to file Parameters ---------- @@ -174,5 +180,5 @@ def write_brillouinzone(self, bz, fmt='.9e'): self.write_data(bz.k, bz.weight, fmt) -add_sile('KP', kpSileSiesta, gzip=True) -add_sile('RKP', rkpSileSiesta, gzip=True) +add_sile("KP", kpSileSiesta, gzip=True) +add_sile("RKP", rkpSileSiesta, gzip=True) diff --git a/src/sisl/io/siesta/orb_indx.py b/src/sisl/io/siesta/orb_indx.py index 9a21526ac6..dc7ca03a87 100644 --- a/src/sisl/io/siesta/orb_indx.py +++ b/src/sisl/io/siesta/orb_indx.py @@ -9,19 +9,19 @@ from ..sile import add_sile, sile_fh_open from .sile import SileSiesta -__all__ = ['orbindxSileSiesta'] +__all__ = ["orbindxSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') +Bohr2Ang = unit_convert("Bohr", "Ang") @set_module("sisl.io.siesta") class orbindxSileSiesta(SileSiesta): - """ Orbital information file """ + """Orbital information file""" @sile_fh_open() def read_lattice_nsc(self): - """ Reads the supercell number of supercell information """ + """Reads the supercell number of supercell information""" # First line contains no no_s line = self.readline().split() no_s = int(line[1]) @@ -46,7 +46,7 @@ def int_abs(i): @sile_fh_open() def read_basis(self, atoms=None): - """ Returns a set of atoms corresponding to the basis-sets in the ORB_INDX file + """Returns a set of atoms corresponding to the basis-sets in the ORB_INDX file The specie names have a short field in the ORB_INDX file, hence the name may not necessarily be the same as provided in the species block @@ -99,7 +99,7 @@ def crt_atom(i_s, spec, orbs): if i_s in specs: continue nlmz = list(map(int, line[5:9])) - P = line[9] == 'T' + P = line[9] == "T" rc = float(line[11]) * Bohr2Ang # Create the orbital o = AtomicOrbital(n=nlmz[0], l=nlmz[1], m=nlmz[2], zeta=nlmz[3], P=P, R=rc) @@ -113,4 +113,4 @@ def crt_atom(i_s, spec, orbs): return Atoms([atom[i] for i in specs]) -add_sile('ORB_INDX', orbindxSileSiesta, gzip=True) +add_sile("ORB_INDX", orbindxSileSiesta, gzip=True) diff --git a/src/sisl/io/siesta/pdos.py b/src/sisl/io/siesta/pdos.py index 0ddeda96a9..ce53b70083 100644 --- a/src/sisl/io/siesta/pdos.py +++ b/src/sisl/io/siesta/pdos.py @@ -24,25 +24,25 @@ from ..sile import add_sile, get_sile, sile_fh_open from .sile import SileSiesta -__all__ = ['pdosSileSiesta'] +__all__ = ["pdosSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') +Bohr2Ang = unit_convert("Bohr", "Ang") @set_module("sisl.io.siesta") class pdosSileSiesta(SileSiesta): - """ Projected DOS file with orbital information + """Projected DOS file with orbital information Data file containing the PDOS as calculated by Siesta. """ def read_geometry(self): - """ Read the geometry with coordinates and correct orbital counts """ + """Read the geometry with coordinates and correct orbital counts""" return self.read_data()[0] @sile_fh_open(True) def read_fermi_level(self): - """ Returns the fermi-level """ + """Returns the fermi-level""" # Get the element-tree root = xml_parse(self.fh).getroot() @@ -54,7 +54,7 @@ def read_fermi_level(self): @sile_fh_open(True) def read_data(self, as_dataarray=False): - r""" Returns data associated with the PDOS file + r"""Returns data associated with the PDOS file For spin-polarized calculations the returned values are up/down, orbitals, energy. For non-collinear calculations the returned values are sum/x/y/z, orbitals, energy. @@ -83,7 +83,9 @@ def read_data(self, as_dataarray=False): Ef = root.find("fermi_energy") E = arrayd(root.find("energy_values").text.split()) if Ef is None: - warn(f"{self!s}.read_data could not locate the Fermi-level in the XML tree, using E_F = 0. eV") + warn( + f"{self!s}.read_data could not locate the Fermi-level in the XML tree, using E_F = 0. eV" + ) else: Ef = float(Ef.text) E -= Ef @@ -93,10 +95,12 @@ def read_data(self, as_dataarray=False): xyz = [] atoms = [] atom_species = [] + def ensure_size(ia): while len(atom_species) <= ia: atom_species.append(None) xyz.append(None) + def ensure_size_orb(ia, i): while len(atoms) <= ia: atoms.append([]) @@ -104,6 +108,7 @@ def ensure_size_orb(ia, i): atoms[ia].append(None) if nspin == 4: + def process(D): tmp = np.empty(D.shape[0], D.dtype) tmp[:] = D[:, 3] @@ -112,18 +117,23 @@ def process(D): D[:, 1] = D[:, 2] D[:, 2] = tmp[:] return D + elif nspin == 2: + def process(D): tmp = D[:, 0] + D[:, 1] D[:, 1] = D[:, 0] - D[:, 1] D[:, 0] = tmp return D + else: + def process(D): return D if as_dataarray: import xarray as xr + if nspin == 1: spin = ["sum"] elif nspin == 2: @@ -135,21 +145,26 @@ def process(D): dims = ["E", "spin", "n", "l", "m", "zeta", "polarization"] shape = (ne, nspin, 1, 1, 1, 1, 1) + def to(o, DOS): # Coordinates for this dataarray - coords = [E, spin, - [o.n], [o.l], [o.m], [o.zeta], [o.P]] + coords = [E, spin, [o.n], [o.l], [o.m], [o.zeta], [o.P]] - return xr.DataArray(data=process(DOS).reshape(shape), - dims=dims, coords=coords, name="PDOS") + return xr.DataArray( + data=process(DOS).reshape(shape), + dims=dims, + coords=coords, + name="PDOS", + ) else: + def to(o, DOS): return process(DOS) + D = [] for orb in root.findall("orbital"): - # Short-hand function to retrieve integers for the attributes def oi(name): return int(orb.get(name)) @@ -214,7 +229,8 @@ def oi(name): D = np.moveaxis(np.stack(D, axis=0), 2, 0) return geom, E, D - @default_ArgumentParser(description=""" + @default_ArgumentParser( + description=""" Extract/Plot data from a PDOS/PDOS.xml file The arguments are parsed as they are passed to the command line; hence order is important. @@ -238,9 +254,10 @@ def oi(name): --spin x --atom all --out spin_x_all.dat --spin y --atom all --out spin_y_all.dat will store the spin x/y components of all atoms in spin_x_all.dat/spin_y_all.dat, respectively. -""") +""" + ) def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" # We limit the import to occur here import argparse @@ -257,7 +274,7 @@ def ArgumentParser(self, p=None, *args, **kwargs): comment = "Fermi-level unknown" def norm(geom, orbitals=None, norm="none"): - r""" Normalization factor depending on the input + r"""Normalization factor depending on the input The normalization can be performed in one of the below methods. In the following :math:`N` refers to the normalization constant @@ -290,7 +307,9 @@ def norm(geom, orbitals=None, norm="none"): elif norm in ["all", "atom", "orbital"]: NORM = geom.no else: - raise ValueError(f"norm error on norm keyword in when requesting normalization!") + raise ValueError( + f"norm error on norm keyword in when requesting normalization!" + ) # If the user requests all orbitals if orbitals is None: @@ -307,7 +326,7 @@ def norm(geom, orbitals=None, norm="none"): return NORM def _sum_filter(PDOS): - """ Default sum is the total DOS, no projection on directions """ + """Default sum is the total DOS, no projection on directions""" if PDOS.ndim == 2: # non-polarized return PDOS @@ -315,20 +334,23 @@ def _sum_filter(PDOS): # polarized return PDOS[0] return PDOS[0] - namespace = default_namespace(_geometry=geometry, - _E=E, - _PDOS=PDOS, - # The energy range of all data - _Erng=None, - _norm="none", - _PDOS_filter_name="total", - _PDOS_filter=_sum_filter, - _data=[], - _data_description=[], - _data_header=[]) + + namespace = default_namespace( + _geometry=geometry, + _E=E, + _PDOS=PDOS, + # The energy range of all data + _Erng=None, + _norm="none", + _PDOS_filter_name="total", + _PDOS_filter=_sum_filter, + _data=[], + _data_description=[], + _data_header=[], + ) def ensure_E(func): - """ This decorater ensures that E is the first element in the _data container """ + """This decorater ensures that E is the first element in the _data container""" def assign_E(self, *args, **kwargs): ns = args[1] @@ -337,13 +359,14 @@ def assign_E(self, *args, **kwargs): ns._data.append(ns._E[ns._Erng].flatten()) ns._data_header.append("Energy[eV]") return func(self, *args, **kwargs) + return assign_E class ERange(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): E = ns._E Emap = strmap(float, value, E.min(), E.max()) + def Eindex(e): return np.abs(E - e).argmin() @@ -354,32 +377,42 @@ def Eindex(e): ns._Erng = None return elif begin is None: - E.append(range(Eindex(end)+1)) + E.append(range(Eindex(end) + 1)) elif end is None: E.append(range(Eindex(begin), len(E))) else: - E.append(range(Eindex(begin), Eindex(end)+1)) + E.append(range(Eindex(begin), Eindex(end) + 1)) # Issuing unique also sorts the entries ns._Erng = np.unique(arrayi(E).flatten()) - p.add_argument("--energy", "-E", action=ERange, - help="""Denote the sub-section of energies that are extracted: "-1:0,1:2" [eV] + + p.add_argument( + "--energy", + "-E", + action=ERange, + help="""Denote the sub-section of energies that are extracted: "-1:0,1:2" [eV] - This flag takes effect on all energy-resolved quantities and is reset whenever --plot or --out is called""") + This flag takes effect on all energy-resolved quantities and is reset whenever --plot or --out is called""", + ) # The normalization method class NormAction(argparse.Action): - @collect_action def __call__(self, parser, ns, value, option_string=None): ns._norm = value - p.add_argument("--norm", "-N", action=NormAction, default="atom", - choices=["none", "atom", "orbital", "all"], - help="""Specify the normalization method; "none") no normalization, "atom") total orbitals in selected atoms, + + p.add_argument( + "--norm", + "-N", + action=NormAction, + default="atom", + choices=["none", "atom", "orbital", "all"], + help="""Specify the normalization method; "none") no normalization, "atom") total orbitals in selected atoms, "orbital") selected orbitals or "all") all orbitals. Will only take effect on subsequent --atom ranges. - This flag is reset whenever --plot or --out is called""") + This flag is reset whenever --plot or --out is called""", + ) if PDOS.ndim == 2: # no spin is possible @@ -387,66 +420,91 @@ def __call__(self, parser, ns, value, option_string=None): elif PDOS.shape[0] == 2: # Add a spin-action class Spin(argparse.Action): - @collect_action def __call__(self, parser, ns, value, option_string=None): value = value[0].lower() if value in ("up", "u"): name = "up" + def _filter(PDOS): return (PDOS[0] + PDOS[1]) / 2 + elif value in ("down", "dn", "dw", "d"): name = "down" + def _filter(PDOS): return (PDOS[0] - PDOS[1]) / 2 + elif value in ("sum", "+", "total"): name = "total" + def _filter(PDOS): return PDOS[0] + elif value in ("z", "spin"): name = "z" + def _filter(PDOS): return PDOS[1] + else: - raise ValueError(f"Wrong argument for --spin [up, down, sum, z], found {value}") + raise ValueError( + f"Wrong argument for --spin [up, down, sum, z], found {value}" + ) ns._PDOS_filter_name = name ns._PDOS_filter = _filter - p.add_argument("--spin", "-S", action=Spin, nargs=1, - help="Which spin-component to store, up/u, down/d, z/spin or sum/+/total") + + p.add_argument( + "--spin", + "-S", + action=Spin, + nargs=1, + help="Which spin-component to store, up/u, down/d, z/spin or sum/+/total", + ) elif PDOS.shape[0] == 4: # Add a spin-action class Spin(argparse.Action): - @collect_action def __call__(self, parser, ns, value, option_string=None): value = value[0].lower() if value in ("sum", "+", "total"): name = "total" + def _filter(PDOS): return PDOS[0] + else: # the stuff must be a range of directions # so simply put it in idx = list(map(direction, value)) name = value + def _filter(PDOS): return PDOS[idx].sum(0) + ns._PDOS_filter_name = name ns._PDOS_filter = _filter - p.add_argument("--spin", "-S", action=Spin, nargs=1, - help="Which spin-component to store, sum/+/total, x, y, z or a sum of either of the directions xy, zx etc.") + + p.add_argument( + "--spin", + "-S", + action=Spin, + nargs=1, + help="Which spin-component to store, sum/+/total, x, y, z or a sum of either of the directions xy, zx etc.", + ) def parse_atom_range(geom, value): if value.lower() in ("all", ":"): return np.arange(geom.no), "all" - value = ",".join(# ensure only single commas (no space between them) - "".join(# ensure no empty whitespaces - ",".join(# join different lines with a comma - value.splitlines()) - .split()) - .split(",")) + value = ",".join( # ensure only single commas (no space between them) + "".join( # ensure no empty whitespaces + ",".join( # join different lines with a comma + value.splitlines() + ).split() + ).split(",") + ) # Sadly many shell interpreters does not # allow simple [] because they are expansion tokens @@ -491,14 +549,15 @@ def parse_atom_range(geom, value): print(f" 1-{len(geometry)}") print("Input atoms:") print(" ", value) - raise ValueError("Atomic/Orbital requests are not fully included in the device region.") + raise ValueError( + "Atomic/Orbital requests are not fully included in the device region." + ) # Add one to make the c-index equivalent to the f-index return np.concatenate(orbs).flatten(), value # Try and add the atomic specification class AtomRange(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, value, option_string=None): @@ -515,19 +574,29 @@ def __call__(self, parser, ns, value, option_string=None): DOS = "PDOS" if ns._PDOS_filter_name is not None: - ns._data_header.append(f"{DOS}[spin={ns._PDOS_filter_name}:{value}][1/eV]") - ns._data_description.append(f"Column {index} is the sum of spin={ns._PDOS_filter_name} on atoms[orbs] {value} with normalization 1/{scale}") + ns._data_header.append( + f"{DOS}[spin={ns._PDOS_filter_name}:{value}][1/eV]" + ) + ns._data_description.append( + f"Column {index} is the sum of spin={ns._PDOS_filter_name} on atoms[orbs] {value} with normalization 1/{scale}" + ) else: ns._data_header.append(f"{DOS}[{value}][1/eV]") - ns._data_description.append(f"Column {index} is the total PDOS on atoms[orbs] {value} with normalization 1/{scale}") + ns._data_description.append( + f"Column {index} is the total PDOS on atoms[orbs] {value} with normalization 1/{scale}" + ) - p.add_argument("--atom", "-a", type=str, action=AtomRange, - help="""Limit orbital resolved PDOS to a sub-set of atoms/orbitals: "1-2[3,4]" will yield the 1st and 2nd atom and their 3rd and fourth orbital. Multiple comma-separated specifications are allowed. Note that some shells does not allow [] as text-input (due to expansion), {, [ or * are allowed orbital delimiters. + p.add_argument( + "--atom", + "-a", + type=str, + action=AtomRange, + help="""Limit orbital resolved PDOS to a sub-set of atoms/orbitals: "1-2[3,4]" will yield the 1st and 2nd atom and their 3rd and fourth orbital. Multiple comma-separated specifications are allowed. Note that some shells does not allow [] as text-input (due to expansion), {, [ or * are allowed orbital delimiters. -Multiple options will create a new column/line in output, the --norm and --E should be before any of these arguments""") +Multiple options will create a new column/line in output, the --norm and --E should be before any of these arguments""", + ) class Out(argparse.Action): - @run_actions def __call__(self, parser, ns, value, option_string=None): out = value[0] @@ -549,14 +618,19 @@ def __call__(self, parser, ns, value, option_string=None): ns._data_header.append("Energy[eV]") ns._data.append(ns._PDOS_filter(ns._PDOS).sum(0)) if ns._PDOS_filter_name is not None: - ns._data_header.append(f"DOS[spin={ns._PDOS_filter_name}][1/eV]") + ns._data_header.append( + f"DOS[spin={ns._PDOS_filter_name}][1/eV]" + ) else: ns._data_header.append("DOS[1/eV]") from sisl.io import tableSile - tableSile(out, mode="w").write(*ns._data, - comment=[comment] + ns._data_description, - header=ns._data_header) + + tableSile(out, mode="w").write( + *ns._data, + comment=[comment] + ns._data_description, + header=ns._data_header, + ) # Clean all data ns._norm = "none" ns._data = [] @@ -566,31 +640,38 @@ def __call__(self, parser, ns, value, option_string=None): ns._PDOS_filter_name = None ns._PDOS_filter = _sum_filter ns._Erng = None - p.add_argument("--out", "-o", nargs=1, action=Out, - help="Store currently collected PDOS (at its current invocation) to the out file.") - class Plot(argparse.Action): + p.add_argument( + "--out", + "-o", + nargs=1, + action=Out, + help="Store currently collected PDOS (at its current invocation) to the out file.", + ) + class Plot(argparse.Action): @run_actions def __call__(self, parser, ns, value, option_string=None): - if len(ns._data) == 0: ns._data.append(ns._E) ns._data_header.append("Energy[eV]") ns._data.append(ns._PDOS_filter(ns._PDOS).sum(0)) if ns._PDOS_filter_name is not None: - ns._data_header.append(f"DOS[spin={ns._PDOS_filter_name}][1/eV]") + ns._data_header.append( + f"DOS[spin={ns._PDOS_filter_name}][1/eV]" + ) else: ns._data_header.append("DOS[1/eV]") from matplotlib import pyplot as plt + plt.figure() def _get_header(header): - header = (header - .replace("PDOS", "") - .replace("DOS", "") - .replace("[1/eV]", "") + header = ( + header.replace("PDOS", "") + .replace("DOS", "") + .replace("[1/eV]", "") ) if len(header) == 0: return "Total" @@ -604,7 +685,12 @@ def _get_header(header): if len(ns._data) > 2: kwargs["alpha"] = 0.6 for i in range(1, len(ns._data)): - plt.plot(ns._data[0], ns._data[i], label=_get_header(ns._data_header[i]), **kwargs) + plt.plot( + ns._data[0], + ns._data[i], + label=_get_header(ns._data_header[i]), + **kwargs, + ) plt.ylabel("DOS [1/eV]") if "unknown" in comment: @@ -627,8 +713,15 @@ def _get_header(header): ns._PDOS_filter_name = None ns._PDOS_filter = _sum_filter ns._Erng = None - p.add_argument("--plot", "-p", action=Plot, nargs="?", metavar="FILE", - help="Plot the currently collected information (at its current invocation).") + + p.add_argument( + "--plot", + "-p", + action=Plot, + nargs="?", + metavar="FILE", + help="Plot the currently collected information (at its current invocation).", + ) return p, namespace diff --git a/src/sisl/io/siesta/siesta_grid.py b/src/sisl/io/siesta/siesta_grid.py index 044934d4ce..6a73d56b85 100644 --- a/src/sisl/io/siesta/siesta_grid.py +++ b/src/sisl/io/siesta/siesta_grid.py @@ -15,46 +15,47 @@ from ..sile import SileError, add_sile, sile_raise_write from .sile import SileCDFSiesta -__all__ = ['gridncSileSiesta'] +__all__ = ["gridncSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ry2eV = unit_convert('Ry', 'eV') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ry2eV = unit_convert("Ry", "eV") @set_module("sisl.io.siesta") class gridncSileSiesta(SileCDFSiesta): - """ NetCDF real-space grid file + """NetCDF real-space grid file The grid sile will automatically convert the units from Siesta units (Bohr, Ry) to sisl units (Ang, eV) provided the correct extension is present. """ def read_lattice(self): - """ Returns a Lattice object from a Siesta.grid.nc file - """ - cell = np.array(self._value('cell'), np.float64) + """Returns a Lattice object from a Siesta.grid.nc file""" + cell = np.array(self._value("cell"), np.float64) # Yes, this is ugly, I really should implement my unit-conversion tool cell *= Bohr2Ang cell.shape = (3, 3) return Lattice(cell) - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def write_lattice(self, lattice): - """ Write a supercell to the grid.nc file """ + """Write a supercell to the grid.nc file""" sile_raise_write(self) # Create initial dimensions - self._crt_dim(self, 'xyz', 3) - self._crt_dim(self, 'abc', 3) + self._crt_dim(self, "xyz", 3) + self._crt_dim(self, "abc", 3) - v = self._crt_var(self, 'cell', 'f8', ('abc', 'xyz')) - v.info = 'Unit cell' - v.unit = 'Bohr' + v = self._crt_var(self, "cell", "f8", ("abc", "xyz")) + v.info = "Unit cell" + v.unit = "Bohr" v[:, :] = lattice.cell[:, :] / Bohr2Ang - def read_grid(self, index=0, name='gridfunc', *args, **kwargs): - """ Reads a grid in the current Siesta.grid.nc file + def read_grid(self, index=0, name="gridfunc", *args, **kwargs): + """Reads a grid in the current Siesta.grid.nc file Enables the reading and processing of the grids created by Siesta @@ -82,28 +83,29 @@ def read_grid(self, index=0, name='gridfunc', *args, **kwargs): # So the first one should be ElectrostaticPotential try: # <>.grid.nc - base = f.split('.')[-3] + base = f.split(".")[-3] except Exception: - base = 'None' + base = "None" # Unit-conversion - BohrC2AngC = Bohr2Ang ** 3 - - unit = {'Rho': 1. / BohrC2AngC, - 'DeltaRho': 1. / BohrC2AngC, - 'RhoXC': 1. / BohrC2AngC, - 'RhoInit': 1. / BohrC2AngC, - 'Chlocal': 1. / BohrC2AngC, - 'TotalCharge': 1. / BohrC2AngC, - 'BaderCharge': 1. / BohrC2AngC, - 'ElectrostaticPotential': Ry2eV, - 'TotalPotential': Ry2eV, - 'Vna': Ry2eV, + BohrC2AngC = Bohr2Ang**3 + + unit = { + "Rho": 1.0 / BohrC2AngC, + "DeltaRho": 1.0 / BohrC2AngC, + "RhoXC": 1.0 / BohrC2AngC, + "RhoInit": 1.0 / BohrC2AngC, + "Chlocal": 1.0 / BohrC2AngC, + "TotalCharge": 1.0 / BohrC2AngC, + "BaderCharge": 1.0 / BohrC2AngC, + "ElectrostaticPotential": Ry2eV, + "TotalPotential": Ry2eV, + "Vna": Ry2eV, }.get(base, None) # Fall-back if unit is None: - unit = 1. + unit = 1.0 show_info = True else: show_info = False @@ -112,18 +114,23 @@ def read_grid(self, index=0, name='gridfunc', *args, **kwargs): lattice = self.read_lattice().swapaxes(0, 2) # Create the grid - nx = len(self._dimension('n1')) - ny = len(self._dimension('n2')) - nz = len(self._dimension('n3')) + nx = len(self._dimension("n1")) + ny = len(self._dimension("n2")) + nz = len(self._dimension("n3")) if name is None: - v = self._variable('gridfunc') + v = self._variable("gridfunc") else: v = self._variable(name) # Create the grid, Siesta uses periodic, always - grid = Grid([nz, ny, nx], bc=Grid.PERIODIC, lattice=lattice, dtype=v.dtype, - geometry=kwargs.get("geometry", None)) + grid = Grid( + [nz, ny, nx], + bc=Grid.PERIODIC, + lattice=lattice, + dtype=v.dtype, + geometry=kwargs.get("geometry", None), + ) if v.ndim == 3: grid.grid[:, :, :] = v[:, :, :] * unit @@ -133,32 +140,34 @@ def read_grid(self, index=0, name='gridfunc', *args, **kwargs): grid_reduce_indices(v, np.array(index) * unit, axis=0, out=grid.grid) if show_info: - info(f"{self.__class__.__name__}.read_grid cannot determine the units of the grid. " - "The units may not be in sisl units.") + info( + f"{self.__class__.__name__}.read_grid cannot determine the units of the grid. " + "The units may not be in sisl units." + ) # Read the grid, we want the z-axis to be the fastest # looping direction, hence x,y,z == 0,1,2 return grid.swapaxes(0, 2) def write_grid(self, grid, spin=0, nspin=None, **kwargs): - """ Write a grid to the grid.nc file """ + """Write a grid to the grid.nc file""" sile_raise_write(self) # Default to *index* variable - spin = kwargs.get('index', spin) + spin = kwargs.get("index", spin) self.write_lattice(grid.lattice) if nspin is not None: - self._crt_dim(self, 'spin', nspin) + self._crt_dim(self, "spin", nspin) - self._crt_dim(self, 'n1', grid.shape[0]) - self._crt_dim(self, 'n2', grid.shape[1]) - self._crt_dim(self, 'n3', grid.shape[2]) + self._crt_dim(self, "n1", grid.shape[0]) + self._crt_dim(self, "n2", grid.shape[1]) + self._crt_dim(self, "n3", grid.shape[2]) if nspin is None: - v = self._crt_var(self, "gridfunc", grid.dtype, ('n3', 'n2', 'n1')) + v = self._crt_var(self, "gridfunc", grid.dtype, ("n3", "n2", "n1")) else: - v = self._crt_var(self, "gridfunc", grid.dtype, ('spin', 'n3', 'n2', 'n1')) + v = self._crt_var(self, "gridfunc", grid.dtype, ("spin", "n3", "n2", "n1")) v.info = "Grid function" if nspin is None: @@ -167,4 +176,4 @@ def write_grid(self, grid, spin=0, nspin=None, **kwargs): v[spin, :, :, :] = np.swapaxes(grid.grid, 0, 2) -add_sile('grid.nc', gridncSileSiesta) +add_sile("grid.nc", gridncSileSiesta) diff --git a/src/sisl/io/siesta/siesta_nc.py b/src/sisl/io/siesta/siesta_nc.py index 6c03568c65..f3bc0aec1e 100644 --- a/src/sisl/io/siesta/siesta_nc.py +++ b/src/sisl/io/siesta/siesta_nc.py @@ -34,25 +34,25 @@ has_fortran_module = False -__all__ = ['ncSileSiesta'] +__all__ = ["ncSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ry2eV = unit_convert('Ry', 'eV') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ry2eV = unit_convert("Ry", "eV") @set_module("sisl.io.siesta") class ncSileSiesta(SileCDFSiesta): - """ Generic NetCDF output file containing a large variety of information """ + """Generic NetCDF output file containing a large variety of information""" @lru_cache(maxsize=1) def read_lattice_nsc(self): - """ Returns number of supercell connections """ - return np.array(self._value('nsc'), np.int32) + """Returns number of supercell connections""" + return np.array(self._value("nsc"), np.int32) @lru_cache(maxsize=1) def read_lattice(self): - """ Returns a Lattice object from a Siesta.nc file """ - cell = np.array(self._value('cell'), np.float64) + """Returns a Lattice object from a Siesta.nc file""" + cell = np.array(self._value("cell"), np.float64) # Yes, this is ugly, I really should implement my unit-conversion tool cell *= Bohr2Ang cell.shape = (3, 3) @@ -63,18 +63,17 @@ def read_lattice(self): @lru_cache(maxsize=1) def read_basis(self): - """ Returns a set of atoms corresponding to the basis-sets in the nc file """ - if 'BASIS' not in self.groups: + """Returns a set of atoms corresponding to the basis-sets in the nc file""" + if "BASIS" not in self.groups: return None - basis = self.groups['BASIS'] + basis = self.groups["BASIS"] atom = [None] * len(basis.groups) for a_str in basis.groups: a = basis.groups[a_str] - if 'orbnl_l' not in a.variables: - + if "orbnl_l" not in a.variables: # Do the easy thing. # Get number of orbitals @@ -87,23 +86,22 @@ def read_basis(self): continue # Retrieve values - orb_l = a.variables['orbnl_l'][:] # angular quantum number - orb_n = a.variables['orbnl_n'][:] # principal quantum number - orb_z = a.variables['orbnl_z'][:] # zeta - orb_P = a.variables['orbnl_ispol'][:] > 0 # polarization shell, or not - orb_q0 = a.variables['orbnl_pop'][:] # q0 for the orbitals - orb_delta = a.variables['delta'][:] # delta for the functions - orb_psi = a.variables['orb'][:, :] + orb_l = a.variables["orbnl_l"][:] # angular quantum number + orb_n = a.variables["orbnl_n"][:] # principal quantum number + orb_z = a.variables["orbnl_z"][:] # zeta + orb_P = a.variables["orbnl_ispol"][:] > 0 # polarization shell, or not + orb_q0 = a.variables["orbnl_pop"][:] # q0 for the orbitals + orb_delta = a.variables["delta"][:] # delta for the functions + orb_psi = a.variables["orb"][:, :] # Now loop over all orbitals orbital = [] # Number of basis-orbitals (before m-expansion) - no = len(a.dimensions['norbs']) + no = len(a.dimensions["norbs"]) # All orbital data for io in range(no): - n = orb_n[io] l = orb_l[io] z = orb_z[io] @@ -122,7 +120,7 @@ def read_basis(self): # The fact that we have to have it normalized means that we need # to convert psi /sqrt(Bohr**3) -> /sqrt(Ang**3) # \int psi^\dagger psi == 1 - psi = orb_psi[io, :] * r ** l / Bohr2Ang ** (3./2.) + psi = orb_psi[io, :] * r**l / Bohr2Ang ** (3.0 / 2.0) # Create the sphericalorbital and then the atomicorbital sorb = SphericalOrbital(l, (r * Bohr2Ang, psi), orb_q0[io]) @@ -141,17 +139,17 @@ def read_basis(self): @lru_cache(maxsize=1) def read_geometry(self): - """ Returns Geometry object from a Siesta.nc file """ + """Returns Geometry object from a Siesta.nc file""" # Read supercell lattice = self.read_lattice() - xyz = np.array(self._value('xa'), np.float64) + xyz = np.array(self._value("xa"), np.float64) xyz.shape = (-1, 3) - if 'BASIS' in self.groups: + if "BASIS" in self.groups: basis = self.read_basis() - species = self.groups['BASIS'].variables['basis'][:] - 1 + species = self.groups["BASIS"].variables["basis"][:] - 1 atom = Atoms([basis[i] for i in species]) else: atom = Atom(1) @@ -164,13 +162,13 @@ def read_geometry(self): @lru_cache(maxsize=1) def read_force(self): - """ Returns a vector with final forces contained. """ - return np.array(self._value('fa')) * Ry2eV / Bohr2Ang + """Returns a vector with final forces contained.""" + return np.array(self._value("fa")) * Ry2eV / Bohr2Ang @lru_cache(maxsize=1) def read_fermi_level(self): - """ Returns the fermi-level """ - return self._value('Ef')[:] * Ry2eV + """Returns the fermi-level""" + return self._value("Ef")[:] * Ry2eV def _read_class(self, cls, dim=1, **kwargs): # Get the default spin channel @@ -178,50 +176,50 @@ def _read_class(self, cls, dim=1, **kwargs): geom = self.read_geometry() # Populate the things - sp = self.groups['SPARSE'] + sp = self.groups["SPARSE"] # Now create the tight-binding stuff (we re-create the # array, hence just allocate the smallest amount possible) C = cls(geom, dim, nnzpr=1) - C._csr.ncol = np.array(sp.variables['n_col'][:], np.int32) + C._csr.ncol = np.array(sp.variables["n_col"][:], np.int32) # Update maximum number of connections (in case future stuff happens) C._csr.ptr = _ncol_to_indptr(C._csr.ncol) - C._csr.col = np.array(sp.variables['list_col'][:], np.int32) - 1 + C._csr.col = np.array(sp.variables["list_col"][:], np.int32) - 1 # Copy information over C._csr._nnz = len(C._csr.col) C._csr._D = np.empty([C._csr.ptr[-1], dim], np.float64) # Convert from isc to sisl isc - _csr_from_sc_off(C.geometry, sp.variables['isc_off'][:, :], C._csr) + _csr_from_sc_off(C.geometry, sp.variables["isc_off"][:, :], C._csr) return C def _read_class_spin(self, cls, **kwargs): # Get the default spin channel - spin = len(self._dimension('spin')) + spin = len(self._dimension("spin")) # First read the geometry geom = self.read_geometry() # Populate the things - sp = self.groups['SPARSE'] + sp = self.groups["SPARSE"] # Since we may read in an orthogonal basis (stored in a Siesta compliant file) # we can check whether it is orthogonal by checking the sum of the absolute S # I.e. whether only diagonal elements are present. - S = np.array(sp.variables['S'][:], np.float64) + S = np.array(sp.variables["S"][:], np.float64) orthogonal = np.abs(S).sum() == geom.no # Now create the tight-binding stuff (we re-create the # array, hence just allocate the smallest amount possible) C = cls(geom, spin, nnzpr=1, orthogonal=orthogonal) - C._csr.ncol = np.array(sp.variables['n_col'][:], np.int32) + C._csr.ncol = np.array(sp.variables["n_col"][:], np.int32) # Update maximum number of connections (in case future stuff happens) C._csr.ptr = _ncol_to_indptr(C._csr.ncol) - C._csr.col = np.array(sp.variables['list_col'][:], np.int32) - 1 + C._csr.col = np.array(sp.variables["list_col"][:], np.int32) - 1 # Copy information over C._csr._nnz = len(C._csr.col) @@ -232,62 +230,66 @@ def _read_class_spin(self, cls, **kwargs): C._csr._D[:, C.S_idx] = S # Convert from isc to sisl isc - _csr_from_sc_off(C.geometry, sp.variables['isc_off'][:, :], C._csr) + _csr_from_sc_off(C.geometry, sp.variables["isc_off"][:, :], C._csr) return C def read_overlap(self, **kwargs): - """ Returns a overlap matrix from the underlying NetCDF file """ + """Returns a overlap matrix from the underlying NetCDF file""" S = self._read_class(Overlap, **kwargs) - sp = self.groups['SPARSE'] - S._csr._D[:, 0] = sp.variables['S'][:] + sp = self.groups["SPARSE"] + S._csr._D[:, 0] = sp.variables["S"][:] return S.transpose(sort=kwargs.get("sort", True)) def read_hamiltonian(self, **kwargs): - """ Returns a Hamiltonian from the underlying NetCDF file """ + """Returns a Hamiltonian from the underlying NetCDF file""" H = self._read_class_spin(Hamiltonian, **kwargs) - sp = self.groups['SPARSE'] - if sp.variables['H'].unit != 'Ry': - raise SileError(f'{self}.read_hamiltonian requires the stored matrix to be in Ry!') + sp = self.groups["SPARSE"] + if sp.variables["H"].unit != "Ry": + raise SileError( + f"{self}.read_hamiltonian requires the stored matrix to be in Ry!" + ) for i in range(len(H.spin)): - H._csr._D[:, i] = sp.variables['H'][i, :] * Ry2eV + H._csr._D[:, i] = sp.variables["H"][i, :] * Ry2eV # fix siesta specific notation _mat_spin_convert(H) # Shift to the Fermi-level - Ef = - self._value('Ef')[:] * Ry2eV + Ef = -self._value("Ef")[:] * Ry2eV H.shift(Ef) return H.transpose(spin=False, sort=kwargs.get("sort", True)) def read_dynamical_matrix(self, **kwargs): - """ Returns a dynamical matrix from the underlying NetCDF file + """Returns a dynamical matrix from the underlying NetCDF file This assumes that the dynamical matrix is stored in the field "H" as would the Hamiltonian. This is counter-intuitive but is required when using PHtrans. """ D = self._read_class_spin(DynamicalMatrix, **kwargs) - sp = self.groups['SPARSE'] - if sp.variables['H'].unit != 'Ry**2': - raise SileError(f'{self}.read_dynamical_matrix requires the stored matrix to be in Ry**2!') - D._csr._D[:, 0] = sp.variables['H'][0, :] * Ry2eV ** 2 + sp = self.groups["SPARSE"] + if sp.variables["H"].unit != "Ry**2": + raise SileError( + f"{self}.read_dynamical_matrix requires the stored matrix to be in Ry**2!" + ) + D._csr._D[:, 0] = sp.variables["H"][0, :] * Ry2eV**2 return D.transpose(sort=kwargs.get("sort", True)) def read_density_matrix(self, **kwargs): - """ Returns a density matrix from the underlying NetCDF file """ + """Returns a density matrix from the underlying NetCDF file""" # This also adds the spin matrix DM = self._read_class_spin(DensityMatrix, **kwargs) - sp = self.groups['SPARSE'] + sp = self.groups["SPARSE"] for i in range(len(DM.spin)): - DM._csr._D[:, i] = sp.variables['DM'][i, :] + DM._csr._D[:, i] = sp.variables["DM"][i, :] # fix siesta specific notation _mat_spin_convert(DM) @@ -295,19 +297,19 @@ def read_density_matrix(self, **kwargs): return DM.transpose(spin=False, sort=kwargs.get("sort", True)) def read_energy_density_matrix(self, **kwargs): - """ Returns energy density matrix from the underlying NetCDF file """ + """Returns energy density matrix from the underlying NetCDF file""" EDM = self._read_class_spin(EnergyDensityMatrix, **kwargs) # Shift to the Fermi-level - Ef = self._value('Ef')[:] * Ry2eV + Ef = self._value("Ef")[:] * Ry2eV if Ef.size == 1: Ef = np.tile(Ef, 2) - sp = self.groups['SPARSE'] + sp = self.groups["SPARSE"] for i in range(len(EDM.spin)): - EDM._csr._D[:, i] = sp.variables['EDM'][i, :] * Ry2eV - if i < 2 and 'DM' in sp.variables: - EDM._csr._D[:, i] -= sp.variables['DM'][i, :] * Ef[i] + EDM._csr._D[:, i] = sp.variables["EDM"][i, :] * Ry2eV + if i < 2 and "DM" in sp.variables: + EDM._csr._D[:, i] -= sp.variables["DM"][i, :] * Ef[i] # fix siesta specific notation _mat_spin_convert(EDM) @@ -315,32 +317,32 @@ def read_energy_density_matrix(self, **kwargs): return EDM.transpose(spin=False, sort=kwargs.get("sort", True)) def read_force_constant(self): - """ Reads the force-constant stored in the nc file + """Reads the force-constant stored in the nc file Returns ------- force constants : numpy.ndarray with 5 dimensions containing all the forces. The 2nd dimensions contains contains the directions, and 3rd dimensions contains -/+ displacements. """ - if not 'FC' in self.groups: - raise SislError(f'{self}.read_force_constant cannot find the FC group.') - fc = self.groups['FC'] + if not "FC" in self.groups: + raise SislError(f"{self}.read_force_constant cannot find the FC group.") + fc = self.groups["FC"] - disp = fc.variables['disp'][0] * Bohr2Ang - f0 = fc.variables['fa0'][:, :] - fc = (fc.variables['fa'][:, :, :, :, :] - f0.reshape(1, 1, 1, -1, 3)) / disp + disp = fc.variables["disp"][0] * Bohr2Ang + f0 = fc.variables["fa0"][:, :] + fc = (fc.variables["fa"][:, :, :, :, :] - f0.reshape(1, 1, 1, -1, 3)) / disp fc[:, :, 1, :, :] *= -1 return fc * Ry2eV / Bohr2Ang @property @lru_cache(maxsize=1) def grids(self): - """ Return a list of available grids in this file. """ + """Return a list of available grids in this file.""" - return list(self.groups['GRID'].variables) + return list(self.groups["GRID"].variables) def read_grid(self, name, index=0, **kwargs): - """ Reads a grid in the current Siesta.nc file + """Reads a grid in the current Siesta.nc file Enables the reading and processing of the grids created by Siesta @@ -360,12 +362,12 @@ def read_grid(self, name, index=0, **kwargs): geom = self.read_geometry() # Shorthand - g = self.groups['GRID'] + g = self.groups["GRID"] # Create the grid - nx = len(g.dimensions['nx']) - ny = len(g.dimensions['ny']) - nz = len(g.dimensions['nz']) + nx = len(g.dimensions["nx"]) + ny = len(g.dimensions["ny"]) + nz = len(g.dimensions["nz"]) # Shorthand variable name v = g.variables[name] @@ -374,16 +376,17 @@ def read_grid(self, name, index=0, **kwargs): grid = Grid([nz, ny, nx], bc=Grid.PERIODIC, geometry=geom, dtype=v.dtype) # Unit-conversion - BohrC2AngC = Bohr2Ang ** 3 - - unit = {'Rho': 1. / BohrC2AngC, - 'RhoInit': 1. / BohrC2AngC, - 'RhoTot': 1. / BohrC2AngC, - 'RhoDelta': 1. / BohrC2AngC, - 'RhoXC': 1. / BohrC2AngC, - 'RhoBader': 1. / BohrC2AngC, - 'Chlocal': 1. / BohrC2AngC, - }.get(name, 1.) + BohrC2AngC = Bohr2Ang**3 + + unit = { + "Rho": 1.0 / BohrC2AngC, + "RhoInit": 1.0 / BohrC2AngC, + "RhoTot": 1.0 / BohrC2AngC, + "RhoDelta": 1.0 / BohrC2AngC, + "RhoXC": 1.0 / BohrC2AngC, + "RhoBader": 1.0 / BohrC2AngC, + "Chlocal": 1.0 / BohrC2AngC, + }.get(name, 1.0) if len(v[:].shape) == 3: grid.grid = v[:, :, :] * unit @@ -393,7 +396,7 @@ def read_grid(self, name, index=0, **kwargs): grid_reduce_indices(v, np.array(index) * unit, axis=0, out=grid.grid) try: - if v.unit == 'Ry': + if v.unit == "Ry": # Convert to ev grid *= Ry2eV except Exception: @@ -402,12 +405,12 @@ def read_grid(self, name, index=0, **kwargs): # Read the grid, we want the z-axis to be the fastest # looping direction, hence x,y,z == 0,1,2 - grid.grid = np.copy(np.swapaxes(grid.grid, 0, 2), order='C') + grid.grid = np.copy(np.swapaxes(grid.grid, 0, 2), order="C") return grid def write_basis(self, atom): - """ Write the current atoms orbitals as the basis + """Write the current atoms orbitals as the basis Parameters ---------- @@ -415,10 +418,10 @@ def write_basis(self, atom): atom specifications to write. """ sile_raise_write(self) - bs = self._crt_grp(self, 'BASIS') + bs = self._crt_grp(self, "BASIS") # Create variable of basis-indices - b = self._crt_var(bs, 'basis', 'i4', ('na_u',)) + b = self._crt_var(bs, "basis", "i4", ("na_u",)) b.info = "Basis of each atom by ID" for isp, (a, ia) in enumerate(atom.iter(True)): @@ -426,8 +429,10 @@ def write_basis(self, atom): if a.tag in bs.groups: # Assert the file sizes if bs.groups[a.tag].Number_of_orbitals != a.no: - raise ValueError(f'File {self.file} has erroneous data ' - 'in regards of the already stored dimensions.') + raise ValueError( + f"File {self.file} has erroneous data " + "in regards of the already stored dimensions." + ) else: ba = bs.createGroup(a.tag) ba.ID = np.int32(isp + 1) @@ -441,94 +446,107 @@ def write_basis(self, atom): ba.Number_of_orbitals = np.int32(a.no) def _write_settings(self): - """ Internal method for writing settings. + """Internal method for writing settings. Sadly the settings are not correct since we have no recollection of what created the matrices. So the values are just *some* values """ # Create the settings - st = self._crt_grp(self, 'SETTINGS') - v = self._crt_var(st, 'ElectronicTemperature', 'f8', ('one',)) + st = self._crt_grp(self, "SETTINGS") + v = self._crt_var(st, "ElectronicTemperature", "f8", ("one",)) v.info = "Electronic temperature used for smearing DOS" v.unit = "Ry" v[:] = 0.025 / Ry2eV - v = self._crt_var(st, 'BZ', 'i4', ('xyz', 'xyz')) + v = self._crt_var(st, "BZ", "i4", ("xyz", "xyz")) v.info = "Grid used for the Brillouin zone integration" v[:, :] = np.identity(3) * 2 - v = self._crt_var(st, 'BZ_displ', 'f8', ('xyz',)) + v = self._crt_var(st, "BZ_displ", "f8", ("xyz",)) v.info = "Monkhorst-Pack k-grid displacements" v.unit = "b**-1" - v[:] = 0. + v[:] = 0.0 def write_geometry(self, geometry): - """ Creates the NetCDF file and writes the geometry information """ + """Creates the NetCDF file and writes the geometry information""" sile_raise_write(self) # Create initial dimensions - self._crt_dim(self, 'one', 1) - self._crt_dim(self, 'n_s', np.prod(geometry.nsc, dtype=np.int32)) - self._crt_dim(self, 'xyz', 3) - self._crt_dim(self, 'no_s', np.prod(geometry.nsc, dtype=np.int32) * geometry.no) - self._crt_dim(self, 'no_u', geometry.no) - self._crt_dim(self, 'na_u', geometry.na) + self._crt_dim(self, "one", 1) + self._crt_dim(self, "n_s", np.prod(geometry.nsc, dtype=np.int32)) + self._crt_dim(self, "xyz", 3) + self._crt_dim(self, "no_s", np.prod(geometry.nsc, dtype=np.int32) * geometry.no) + self._crt_dim(self, "no_u", geometry.no) + self._crt_dim(self, "na_u", geometry.na) # Create initial geometry - v = self._crt_var(self, 'nsc', 'i4', ('xyz',)) - v.info = 'Number of supercells in each unit-cell direction' - v = self._crt_var(self, 'lasto', 'i4', ('na_u',)) - v.info = 'Last orbital of equivalent atom' - v = self._crt_var(self, 'xa', 'f8', ('na_u', 'xyz')) - v.info = 'Atomic coordinates' - v.unit = 'Bohr' - v = self._crt_var(self, 'cell', 'f8', ('xyz', 'xyz')) - v.info = 'Unit cell' - v.unit = 'Bohr' + v = self._crt_var(self, "nsc", "i4", ("xyz",)) + v.info = "Number of supercells in each unit-cell direction" + v = self._crt_var(self, "lasto", "i4", ("na_u",)) + v.info = "Last orbital of equivalent atom" + v = self._crt_var(self, "xa", "f8", ("na_u", "xyz")) + v.info = "Atomic coordinates" + v.unit = "Bohr" + v = self._crt_var(self, "cell", "f8", ("xyz", "xyz")) + v.info = "Unit cell" + v.unit = "Bohr" # Create designation of the creation - self.method = 'sisl' + self.method = "sisl" # Save stuff - self.variables['nsc'][:] = geometry.nsc - self.variables['xa'][:] = geometry.xyz / Bohr2Ang - self.variables['cell'][:] = geometry.cell / Bohr2Ang + self.variables["nsc"][:] = geometry.nsc + self.variables["xa"][:] = geometry.xyz / Bohr2Ang + self.variables["cell"][:] = geometry.cell / Bohr2Ang # Create basis group self.write_basis(geometry.atoms) # Store the lasto variable as the remaining thing to do - self.variables['lasto'][:] = geometry.lasto + 1 + self.variables["lasto"][:] = geometry.lasto + 1 def _write_sparsity(self, csr, nsc): if csr.nnz != len(csr.col): - raise ValueError(f"{self.file}._write_sparsity *must* be a finalized sparsity matrix") + raise ValueError( + f"{self.file}._write_sparsity *must* be a finalized sparsity matrix" + ) # Create sparse group - sp = self._crt_grp(self, 'SPARSE') - - if 'n_col' in sp.variables: - if len(sp.dimensions['nnzs']) != csr.nnz or \ - np.any(sp.variables['n_col'][:] != csr.ncol[:]) or \ - np.any(sp.variables['list_col'][:] != csr.col[:]+1) or \ - np.any(sp.variables['isc_off'][:] != _siesta.siesta_sc_off(*nsc).T): - raise ValueError(f"{self.file} sparsity pattern stored *MUST* be equivalent for all matrices") + sp = self._crt_grp(self, "SPARSE") + + if "n_col" in sp.variables: + if ( + len(sp.dimensions["nnzs"]) != csr.nnz + or np.any(sp.variables["n_col"][:] != csr.ncol[:]) + or np.any(sp.variables["list_col"][:] != csr.col[:] + 1) + or np.any(sp.variables["isc_off"][:] != _siesta.siesta_sc_off(*nsc).T) + ): + raise ValueError( + f"{self.file} sparsity pattern stored *MUST* be equivalent for all matrices" + ) else: - self._crt_dim(sp, 'nnzs', csr.col.shape[0]) - v = self._crt_var(sp, 'n_col', 'i4', ('no_u',)) + self._crt_dim(sp, "nnzs", csr.col.shape[0]) + v = self._crt_var(sp, "n_col", "i4", ("no_u",)) v.info = "Number of non-zero elements per row" v[:] = csr.ncol[:] - v = self._crt_var(sp, 'list_col', 'i4', ('nnzs',), - chunksizes=(len(csr.col),), **self._cmp_args) + v = self._crt_var( + sp, + "list_col", + "i4", + ("nnzs",), + chunksizes=(len(csr.col),), + **self._cmp_args, + ) v.info = "Supercell column indices in the sparse format" v[:] = csr.col[:] + 1 # correct for fortran indices - v = self._crt_var(sp, 'isc_off', 'i4', ('n_s', 'xyz')) + v = self._crt_var(sp, "isc_off", "i4", ("n_s", "xyz")) v.info = "Index of supercell coordinates" v[:, :] = _siesta.siesta_sc_off(*nsc).T return sp def _write_overlap(self, spgroup, csr, orthogonal, S_idx): - v = self._crt_var(spgroup, 'S', 'f8', ('nnzs',), - chunksizes=(len(csr.col),), **self._cmp_args) + v = self._crt_var( + spgroup, "S", "f8", ("nnzs",), chunksizes=(len(csr.col),), **self._cmp_args + ) v.info = "Overlap matrix" if orthogonal: # We need to create the orthogonal pattern @@ -555,21 +573,25 @@ def _write_overlap(self, spgroup, csr, orthogonal, S_idx): row = row[diag_idx] idx = (np.diff(row) != 1).nonzero()[0] row = row[idx] + 1 - raise ValueError(f'{self}._write_overlap ' - 'is trying to write an Overlap in Siesta format with ' - f'missing diagonal terms on rows {row}. Please explicitly add *all* diagonal overlap terms.') + raise ValueError( + f"{self}._write_overlap " + "is trying to write an Overlap in Siesta format with " + f"missing diagonal terms on rows {row}. Please explicitly add *all* diagonal overlap terms." + ) - D[idx[diag_idx]] = 1. + D[idx[diag_idx]] = 1.0 v[:] = tmp._D[:, 0] del tmp else: v[:] = csr._D[:, S_idx] def write_overlap(self, S, **kwargs): - """ Write the overlap matrix to the NetCDF file """ + """Write the overlap matrix to the NetCDF file""" csr = S.transpose(sort=False)._csr if csr.nnz == 0: - raise SileError(f'{self}.write_overlap cannot write a zero element sparse matrix!') + raise SileError( + f"{self}.write_overlap cannot write a zero element sparse matrix!" + ) # Convert to siesta CSR _csr_to_siesta(S.geometry, csr) @@ -584,7 +606,7 @@ def write_overlap(self, S, **kwargs): self._write_overlap(spgroup, csr, S.orthogonal, S.S_idx) def write_hamiltonian(self, H, **kwargs): - """ Writes Hamiltonian model to file + """Writes Hamiltonian model to file Parameters ---------- @@ -595,7 +617,9 @@ def write_hamiltonian(self, H, **kwargs): """ csr = H.transpose(spin=False, sort=False)._csr if csr.nnz == 0: - raise SileError(f'{self}.write_hamiltonian cannot write a zero element sparse matrix!') + raise SileError( + f"{self}.write_hamiltonian cannot write a zero element sparse matrix!" + ) # Convert to siesta CSR _csr_to_siesta(H.geometry, csr) @@ -605,18 +629,20 @@ def write_hamiltonian(self, H, **kwargs): # Ensure that the geometry is written self.write_geometry(H.geometry) - self._crt_dim(self, 'spin', len(H.spin)) + self._crt_dim(self, "spin", len(H.spin)) - if H.dkind != 'f': - raise NotImplementedError('Currently we only allow writing a floating point Hamiltonian to the Siesta format') + if H.dkind != "f": + raise NotImplementedError( + "Currently we only allow writing a floating point Hamiltonian to the Siesta format" + ) - v = self._crt_var(self, 'Ef', 'f8', ('one',)) - v.info = 'Fermi level' - v.unit = 'Ry' - v[:] = kwargs.get('Ef', 0.) / Ry2eV - v = self._crt_var(self, 'Qtot', 'f8', ('one',)) - v.info = 'Total charge' - v[0] = kwargs.get('Q', kwargs.get('Qtot', H.geometry.q0)) + v = self._crt_var(self, "Ef", "f8", ("one",)) + v.info = "Fermi level" + v.unit = "Ry" + v[:] = kwargs.get("Ef", 0.0) / Ry2eV + v = self._crt_var(self, "Qtot", "f8", ("one",)) + v.info = "Total charge" + v[0] = kwargs.get("Q", kwargs.get("Qtot", H.geometry.q0)) # Append the sparsity pattern spgroup = self._write_sparsity(csr, H.geometry.nsc) @@ -624,8 +650,14 @@ def write_hamiltonian(self, H, **kwargs): # Save sparse matrices self._write_overlap(spgroup, csr, H.orthogonal, H.S_idx) - v = self._crt_var(spgroup, 'H', 'f8', ('spin', 'nnzs'), - chunksizes=(1, len(csr.col)), **self._cmp_args) + v = self._crt_var( + spgroup, + "H", + "f8", + ("spin", "nnzs"), + chunksizes=(1, len(csr.col)), + **self._cmp_args, + ) v.info = "Hamiltonian" v.unit = "Ry" for i in range(len(H.spin)): @@ -634,7 +666,7 @@ def write_hamiltonian(self, H, **kwargs): self._write_settings() def write_density_matrix(self, DM, **kwargs): - """ Writes density matrix model to file + """Writes density matrix model to file Parameters ---------- @@ -643,7 +675,9 @@ def write_density_matrix(self, DM, **kwargs): """ csr = DM.transpose(spin=False, sort=False)._csr if csr.nnz == 0: - raise SileError(f'{self}.write_density_matrix cannot write a zero element sparse matrix!') + raise SileError( + f"{self}.write_density_matrix cannot write a zero element sparse matrix!" + ) # Convert to siesta CSR (we don't need to sort this matrix) _csr_to_siesta(DM.geometry, csr) @@ -653,18 +687,20 @@ def write_density_matrix(self, DM, **kwargs): # Ensure that the geometry is written self.write_geometry(DM.geometry) - self._crt_dim(self, 'spin', len(DM.spin)) + self._crt_dim(self, "spin", len(DM.spin)) - if DM.dkind != 'f': - raise NotImplementedError('Currently we only allow writing a floating point density matrix to the Siesta format') + if DM.dkind != "f": + raise NotImplementedError( + "Currently we only allow writing a floating point density matrix to the Siesta format" + ) - v = self._crt_var(self, 'Qtot', 'f8', ('one',)) - v.info = 'Total charge' + v = self._crt_var(self, "Qtot", "f8", ("one",)) + v.info = "Total charge" v[:] = np.sum(DM.geometry.atoms.q0) - if 'Qtot' in kwargs: - v[:] = kwargs['Qtot'] - if 'Q' in kwargs: - v[:] = kwargs['Q'] + if "Qtot" in kwargs: + v[:] = kwargs["Qtot"] + if "Q" in kwargs: + v[:] = kwargs["Q"] # Append the sparsity pattern spgroup = self._write_sparsity(csr, DM.geometry.nsc) @@ -672,8 +708,14 @@ def write_density_matrix(self, DM, **kwargs): # Save sparse matrices self._write_overlap(spgroup, csr, DM.orthogonal, DM.S_idx) - v = self._crt_var(spgroup, 'DM', 'f8', ('spin', 'nnzs'), - chunksizes=(1, len(csr.col)), **self._cmp_args) + v = self._crt_var( + spgroup, + "DM", + "f8", + ("spin", "nnzs"), + chunksizes=(1, len(csr.col)), + **self._cmp_args, + ) v.info = "Density matrix" for i in range(len(DM.spin)): v[i, :] = csr._D[:, i] @@ -681,7 +723,7 @@ def write_density_matrix(self, DM, **kwargs): self._write_settings() def write_energy_density_matrix(self, EDM, **kwargs): - """ Writes energy density matrix model to file + """Writes energy density matrix model to file Parameters ---------- @@ -690,7 +732,9 @@ def write_energy_density_matrix(self, EDM, **kwargs): """ csr = EDM.transpose(spin=False, sort=False)._csr if csr.nnz == 0: - raise SileError(f'{self}.write_energy_density_matrix cannot write a zero element sparse matrix!') + raise SileError( + f"{self}.write_energy_density_matrix cannot write a zero element sparse matrix!" + ) # no need to sort this matrix _csr_to_siesta(EDM.geometry, csr) @@ -700,22 +744,24 @@ def write_energy_density_matrix(self, EDM, **kwargs): # Ensure that the geometry is written self.write_geometry(EDM.geometry) - self._crt_dim(self, 'spin', len(EDM.spin)) + self._crt_dim(self, "spin", len(EDM.spin)) - if EDM.dkind != 'f': - raise NotImplementedError('Currently we only allow writing a floating point density matrix to the Siesta format') + if EDM.dkind != "f": + raise NotImplementedError( + "Currently we only allow writing a floating point density matrix to the Siesta format" + ) - v = self._crt_var(self, 'Ef', 'f8', ('one',)) - v.info = 'Fermi level' - v.unit = 'Ry' - v[:] = kwargs.get('Ef', 0.) / Ry2eV - v = self._crt_var(self, 'Qtot', 'f8', ('one',)) - v.info = 'Total charge' + v = self._crt_var(self, "Ef", "f8", ("one",)) + v.info = "Fermi level" + v.unit = "Ry" + v[:] = kwargs.get("Ef", 0.0) / Ry2eV + v = self._crt_var(self, "Qtot", "f8", ("one",)) + v.info = "Total charge" v[:] = np.sum(EDM.geometry.atoms.q0) - if 'Qtot' in kwargs: - v[:] = kwargs['Qtot'] - if 'Q' in kwargs: - v[:] = kwargs['Q'] + if "Qtot" in kwargs: + v[:] = kwargs["Qtot"] + if "Q" in kwargs: + v[:] = kwargs["Q"] # Append the sparsity pattern spgroup = self._write_sparsity(csr, EDM.geometry.nsc) @@ -723,8 +769,14 @@ def write_energy_density_matrix(self, EDM, **kwargs): # Save sparse matrices self._write_overlap(spgroup, csr, EDM.orthogonal, EDM.S_idx) - v = self._crt_var(spgroup, 'EDM', 'f8', ('spin', 'nnzs'), - chunksizes=(1, len(csr.col)), **self._cmp_args) + v = self._crt_var( + spgroup, + "EDM", + "f8", + ("spin", "nnzs"), + chunksizes=(1, len(csr.col)), + **self._cmp_args, + ) v.info = "Energy density matrix" v.unit = "Ry" for i in range(len(EDM.spin)): @@ -733,7 +785,7 @@ def write_energy_density_matrix(self, EDM, **kwargs): self._write_settings() def write_dynamical_matrix(self, D, **kwargs): - """ Writes dynamical matrix model to file + """Writes dynamical matrix model to file Parameters ---------- @@ -742,7 +794,9 @@ def write_dynamical_matrix(self, D, **kwargs): """ csr = D.transpose(sort=False)._csr if csr.nnz == 0: - raise SileError(f'{self}.write_dynamical_matrix cannot write a zero element sparse matrix!') + raise SileError( + f"{self}.write_dynamical_matrix cannot write a zero element sparse matrix!" + ) # Convert to siesta CSR _csr_to_siesta(D.geometry, csr) @@ -751,19 +805,21 @@ def write_dynamical_matrix(self, D, **kwargs): # Ensure that the geometry is written self.write_geometry(D.geometry) - self._crt_dim(self, 'spin', 1) + self._crt_dim(self, "spin", 1) - if D.dkind != 'f': - raise NotImplementedError('Currently we only allow writing a floating point dynamical matrix to the Siesta format') + if D.dkind != "f": + raise NotImplementedError( + "Currently we only allow writing a floating point dynamical matrix to the Siesta format" + ) - v = self._crt_var(self, 'Ef', 'f8', ('one',)) - v.info = 'Fermi level' - v.unit = 'Ry' - v[:] = 0. - v = self._crt_var(self, 'Qtot', 'f8', ('one',)) - v.info = 'Total charge' - v.unit = 'e' - v[:] = 0. + v = self._crt_var(self, "Ef", "f8", ("one",)) + v.info = "Fermi level" + v.unit = "Ry" + v[:] = 0.0 + v = self._crt_var(self, "Qtot", "f8", ("one",)) + v.info = "Total charge" + v.unit = "e" + v[:] = 0.0 # Append the sparsity pattern spgroup = self._write_sparsity(csr, D.geometry.nsc) @@ -771,19 +827,25 @@ def write_dynamical_matrix(self, D, **kwargs): # Save sparse matrices self._write_overlap(spgroup, csr, D.orthogonal, D.S_idx) - v = self._crt_var(spgroup, 'H', 'f8', ('spin', 'nnzs'), - chunksizes=(1, len(csr.col)), **self._cmp_args) + v = self._crt_var( + spgroup, + "H", + "f8", + ("spin", "nnzs"), + chunksizes=(1, len(csr.col)), + **self._cmp_args, + ) v.info = "Dynamical matrix" v.unit = "Ry**2" - v[0, :] = csr._D[:, 0] / Ry2eV ** 2 + v[0, :] = csr._D[:, 0] / Ry2eV**2 self._write_settings() def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('nc', ncSileSiesta) +add_sile("nc", ncSileSiesta) diff --git a/src/sisl/io/siesta/sile.py b/src/sisl/io/siesta/sile.py index ab844fb5dd..2563a67f4f 100644 --- a/src/sisl/io/siesta/sile.py +++ b/src/sisl/io/siesta/sile.py @@ -4,6 +4,7 @@ try: from . import _siesta + has_fortran_module = True except ImportError: has_fortran_module = False @@ -12,7 +13,7 @@ from ..sile import Sile, SileBin, SileCDF, SileError -__all__ = ['SileSiesta', 'SileCDFSiesta', 'SileBinSiesta'] +__all__ = ["SileSiesta", "SileCDFSiesta", "SileBinSiesta"] @set_module("sisl.io.siesta") @@ -22,7 +23,6 @@ class SileSiesta(Sile): @set_module("sisl.io.siesta") class SileCDFSiesta(SileCDF): - # all netcdf output should not be masked def _setup(self, *args, **kwargs): super()._setup(*args, **kwargs) @@ -35,7 +35,6 @@ def _setup(self, *args, **kwargs): @set_module("sisl.io.siesta") class SileBinSiesta(SileBin): - def _setup(self, *args, **kwargs): """We set up everything to handle the fortran I/O unit""" super()._setup(*args, **kwargs) @@ -43,9 +42,9 @@ def _setup(self, *args, **kwargs): def _fortran_check(self, method, message, ret_msg=False): ierr = _siesta.io_m.iostat_query() - msg = '' + msg = "" if ierr != 0: - msg = f'{self!s}.{method} {message} (ierr={ierr})' + msg = f"{self!s}.{method} {message} (ierr={ierr})" if not ret_msg: raise SileError(msg) if ret_msg: @@ -62,13 +61,18 @@ def _fortran_open(self, mode, rewind=False): # retain indices return else: - if mode == 'r': + if mode == "r": self._iu = _siesta.io_m.open_file_read(self.file) - elif mode == 'w': + elif mode == "w": self._iu = _siesta.io_m.open_file_write(self.file) else: - raise SileError(f"Mode '{mode}' is not an accepted mode to open a fortran file unit. Use only 'r' or 'w'") - self._fortran_check('_fortran_open', 'could not open for {}.'.format({'r': 'reading', 'w': 'writing'}[mode])) + raise SileError( + f"Mode '{mode}' is not an accepted mode to open a fortran file unit. Use only 'r' or 'w'" + ) + self._fortran_check( + "_fortran_open", + "could not open for {}.".format({"r": "reading", "w": "writing"}[mode]), + ) def _fortran_close(self): if not self._fortran_is_open(): diff --git a/src/sisl/io/siesta/stdout.py b/src/sisl/io/siesta/stdout.py index 64ffe25d93..fe55b9f60e 100644 --- a/src/sisl/io/siesta/stdout.py +++ b/src/sisl/io/siesta/stdout.py @@ -21,12 +21,12 @@ __all__ = ["stdoutSileSiesta", "outSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') +Bohr2Ang = unit_convert("Bohr", "Ang") _A = SileSiesta.InfoAttr def _ensure_atoms(atoms): - """ Ensures that the atoms list is a list with entries (converts `None` to a list). """ + """Ensures that the atoms list is a list with entries (converts `None` to a list).""" if atoms is None: return [Atom(i) for i in range(150)] elif len(atoms) == 0: @@ -36,25 +36,32 @@ def _ensure_atoms(atoms): @set_module("sisl.io.siesta") class stdoutSileSiesta(SileSiesta): - """ Output file from Siesta + """Output file from Siesta This enables reading the output quantities from the Siesta output. """ - + _info_attributes_ = [ - _A("completed", r".*Job completed", - lambda attr, match: lambda : True, default=lambda : False), + _A( + "completed", + r".*Job completed", + lambda attr, match: lambda: True, + default=lambda: False, + ), ] - @deprecation("stdoutSileSiesta.completed is deprecated in favor of stdoutSileSiesta.info.completed", "0.16.0") + @deprecation( + "stdoutSileSiesta.completed is deprecated in favor of stdoutSileSiesta.info.completed", + "0.16.0", + ) def completed(self): - """ True if the full file has been read and "Job completed" was found. """ + """True if the full file has been read and "Job completed" was found.""" return self.info.completed() @lru_cache(1) @sile_fh_open(True) def read_basis(self): - """ Reads the basis as found in the output file + """Reads the basis as found in the output file This parses 3 things: @@ -69,17 +76,17 @@ def read_basis(self): atoms = {} order = [] - while 'Species number:' in line: + while "Species number:" in line: ls = line.split() - if ls[3] == 'Atomic': - atoms[ls[7]] = {'Z': int(ls[5]), 'tag': ls[7]} + if ls[3] == "Atomic": + atoms[ls[7]] = {"Z": int(ls[5]), "tag": ls[7]} order.append(ls[7]) else: - atoms[ls[4]] = {'Z': int(ls[7]), 'tag': ls[4]} + atoms[ls[4]] = {"Z": int(ls[7]), "tag": ls[4]} order.append(ls[4]) line = self.readline() - # Now go down to basis_specs + # Now go down to basis_specs found, line = self.step_to("") while found: # ===== @@ -98,6 +105,7 @@ def read_basis(self): line = self.readline() from .fdf import fdfSileSiesta + atom_orbs = fdfSileSiesta._parse_pao_basis(block) for atom, orbs in atom_orbs.items(): atoms[atom]["orbitals"] = orbs @@ -105,14 +113,16 @@ def read_basis(self): return [Atom(**atoms[tag]) for tag in order] def _r_lattice_outcell(self): - """ Wrapper for reading the unit-cell from the outcoor block """ + """Wrapper for reading the unit-cell from the outcoor block""" # Read until outcell is found found, line = self.step_to("outcell: Unit cell vectors") if not found: - raise ValueError(f"{self.__class__.__name__}._r_lattice_outcell did not find outcell key") + raise ValueError( + f"{self.__class__.__name__}._r_lattice_outcell did not find outcell key" + ) - Ang = 'Ang' in line + Ang = "Ang" in line # We read the unit-cell vectors (in Ang) cell = [] @@ -130,14 +140,14 @@ def _r_lattice_outcell(self): return Lattice(cell) def _r_geometry_outcoor(self, line, atoms=None): - """ Wrapper for reading the geometry as in the outcoor output """ + """Wrapper for reading the geometry as in the outcoor output""" atoms_order = _ensure_atoms(atoms) - is_final = 'Relaxed' in line or 'Final (unrelaxed)' in line + is_final = "Relaxed" in line or "Final (unrelaxed)" in line # Now we have outcoor - scaled = 'scaled' in line - fractional = 'fractional' in line - Ang = 'Ang' in line + scaled = "scaled" in line + fractional = "fractional" in line + Ang = "Ang" in line # Read in data xyz = [] @@ -167,7 +177,9 @@ def _r_geometry_outcoor(self, line, atoms=None): # The output file for siesta does not # contain the lattice constant. # So... :( - raise ValueError("Could not read the lattice-constant for the scaled geometry") + raise ValueError( + "Could not read the lattice-constant for the scaled geometry" + ) elif fractional: xyz = xyz.dot(cell.cell) elif not Ang: @@ -176,11 +188,11 @@ def _r_geometry_outcoor(self, line, atoms=None): return Geometry(xyz, atoms, lattice=cell) def _r_geometry_atomic(self, line, atoms=None): - """ Wrapper for reading the geometry as in the outcoor output """ + """Wrapper for reading the geometry as in the outcoor output""" atoms_order = _ensure_atoms(atoms) # Now we have outcoor - Ang = 'Ang' in line + Ang = "Ang" in line # Read in data xyz = [] @@ -189,7 +201,7 @@ def _r_geometry_atomic(self, line, atoms=None): while len(line.strip()) > 0: line = line.split() xyz.append([float(x) for x in line[1:4]]) - atoms.append(atoms_order[int(line[4])-1]) + atoms.append(atoms_order[int(line[4]) - 1]) line = self.readline() # Retrieve the unit-cell (but do not skip file-descriptor position) @@ -207,7 +219,7 @@ def _r_geometry_atomic(self, line, atoms=None): @sile_fh_open() def read_geometry(self, last=True, all=False): - """ Reads the geometry from the Siesta output file + """Reads the geometry from the Siesta output file Parameters ---------- @@ -229,17 +241,17 @@ def read_geometry(self, last=True, all=False): last = False def func_none(*args, **kwargs): - """ Wrapper to return None """ + """Wrapper to return None""" return None def next_geom(): coord, func = 0, func_none - line = ' ' - while coord == 0 and line != '': + line = " " + while coord == 0 and line != "": line = self.readline() - if 'outcoor' in line and 'coordinates' in line: + if "outcoor" in line and "coordinates" in line: coord, func = 1, self._r_geometry_outcoor - elif 'siesta: Atomic coordinates' in line: + elif "siesta: Atomic coordinates" in line: coord, func = 2, self._r_geometry_atomic return coord, func(line, atoms) @@ -271,7 +283,7 @@ def next_geom(): @sile_fh_open(True) def read_force(self, last=True, all=False, total=False, max=False, key="siesta"): - """ Reads the forces from the Siesta output file + """Reads the forces from the Siesta output file Parameters ---------- @@ -291,7 +303,7 @@ def read_force(self, last=True, all=False, total=False, max=False, key="siesta") Note that this is not the same as doing `max(outSile.read_force(total=True))` since the forces returned in that case are averages on each axis. key: {"siesta", "ts"} - Specifies the indicator string for the forces that are to be read. + Specifies the indicator string for the forces that are to be read. The function will look for a line containing ``f'{key}: Atomic forces'`` to start reading forces. @@ -316,7 +328,6 @@ def read_force(self, last=True, all=False, total=False, max=False, key="siesta") # Read until forces are found def next_force(): - found, line = self.step_to(f"{key}: Atomic forces", allow_reread=False) if not found: return None @@ -324,18 +335,18 @@ def next_force(): # Now read data F = [] line = self.readline() - if 'siesta:' in line: + if "siesta:" in line: # This is the final summary, we don't need to read it as it does not contain new information # and also it make break things since max forces are not written there return None # First, we encounter the atomic forces - while '---' not in line: + while "---" not in line: line = line.split() if not (total or max): F.append([float(x) for x in line[-3:]]) line = self.readline() - if line == '': + if line == "": break line = self.readline() @@ -344,7 +355,7 @@ def next_force(): F = [float(x) for x in line.split()[-3:]] line = self.readline() - #And after that we can read the max force + # And after that we can read the max force if max and len(line.split()) != 0: line = self.readline() maxF = float(line.split()[1]) @@ -360,7 +371,8 @@ def next_force(): def return_forces(Fs): # Handle cases where we can't now if they are found - if Fs is None: return None + if Fs is None: + return None Fs = _a.arrayd(Fs) if max and total: return (Fs[..., :-1], Fs[..., -1]) @@ -386,7 +398,7 @@ def return_forces(Fs): @sile_fh_open(True) def read_stress(self, key="static", last=True, all=False): - """ Reads the stresses from the Siesta output file + """Reads the stresses from the Siesta output file Parameters ---------- @@ -409,10 +421,9 @@ def read_stress(self, key="static", last=True, all=False): # Read until stress are found def next_stress(): - found, line = self.step_to(f"siesta: Stress tensor", allow_reread=False) found = found and key in line - while not found and line != '': + while not found and line != "": found, line = self.step_to(f"siesta: Stress tensor", allow_reread=False) found = found and key in line if not found: @@ -444,8 +455,8 @@ def next_stress(): return next_stress() @sile_fh_open(True) - def read_moment(self, orbitals=False, quantity='S', last=True, all=False): - """ Reads the moments from the Siesta output file + def read_moment(self, orbitals=False, quantity="S", last=True, all=False): + """Reads the moments from the Siesta output file These will only be present in case of spin-orbit coupling. @@ -469,26 +480,25 @@ def read_moment(self, orbitals=False, quantity='S', last=True, all=False): # The moments are printed in SPECIES list itt = iter(self) - next(itt) # empty - next(itt) # empty + next(itt) # empty + next(itt) # empty na = 0 # Loop the species tbl = [] # Read the species label while True: - next(itt) # "" - next(itt) # Atom Orb ... + next(itt) # "" + next(itt) # Atom Orb ... # Loop atoms in this species list while True: line = next(itt) - if line.startswith('Species') or \ - line.startswith('--'): + if line.startswith("Species") or line.startswith("--"): break - line = ' ' + line = " " atom = [] ia = 0 - while not line.startswith('--'): + while not line.startswith("--"): line = next(itt).split() if ia == 0: ia = int(line[0]) @@ -496,19 +506,19 @@ def read_moment(self, orbitals=False, quantity='S', last=True, all=False): raise ValueError("Error in moments formatting.") # Track maximum number of atoms na = max(ia, na) - if quantity == 'S': + if quantity == "S": atom.append([float(x) for x in line[4:7]]) - elif quantity == 'L': + elif quantity == "L": atom.append([float(x) for x in line[7:10]]) - line = next(itt).split() # Total ... + line = next(itt).split() # Total ... if not orbitals: ia = int(line[0]) - if quantity == 'S': + if quantity == "S": atom.append([float(x) for x in line[4:7]]) - elif quantity == 'L': + elif quantity == "L": atom.append([float(x) for x in line[8:11]]) tbl.append((ia, atom)) - if line.startswith('--'): + if line.startswith("--"): break # Sort according to the atomic index @@ -516,7 +526,7 @@ def read_moment(self, orbitals=False, quantity='S', last=True, all=False): # Insert in the correct atomic for ia, atom in tbl: - moments[ia-1] = atom + moments[ia - 1] = atom if not all: return _a.arrayd(moments) @@ -524,7 +534,7 @@ def read_moment(self, orbitals=False, quantity='S', last=True, all=False): @sile_fh_open(True) def read_energy(self): - """ Reads the final energy distribution + """Reads the final energy distribution Currently the energies translated are: @@ -606,7 +616,7 @@ def read_energy(self): "Fermi": "fermi", "Enegf": "negf", "(Free)E+ p_basis*V_orbitals": "basis.enthalpy", - "(Free)E + p_basis*V_orbitals": "basis.enthalpy", # we may correct the missing space + "(Free)E + p_basis*V_orbitals": "basis.enthalpy", # we may correct the missing space } def assign(out, key, val): @@ -614,7 +624,9 @@ def assign(out, key, val): try: val = float(val) except ValueError: - warn(f"Could not convert energy '{key}' ({val}) to a float, assigning nan.") + warn( + f"Could not convert energy '{key}' ({val}) to a float, assigning nan." + ) val = np.nan if "." in key: @@ -641,7 +653,7 @@ def assign(out, key, val): return out def read_data(self, *args, **kwargs): - """ Read specific content in the Siesta out file + """Read specific content in the Siesta out file The currently implemented things are denoted in the parameters list. @@ -689,8 +701,10 @@ def read_data(self, *args, **kwargs): return val @sile_fh_open(True) - def read_scf(self, key="scf", iscf=-1, imd=None, as_dataframe=False, ret_header=False): - r""" Parse SCF information and return a table of SCF information depending on what is requested + def read_scf( + self, key="scf", iscf=-1, imd=None, as_dataframe=False, ret_header=False + ): + r"""Parse SCF information and return a table of SCF information depending on what is requested Parameters ---------- @@ -706,53 +720,60 @@ def read_scf(self, key="scf", iscf=-1, imd=None, as_dataframe=False, ret_header= as_dataframe: boolean, optional whether the information should be returned as a `pandas.DataFrame`. The advantage of this format is that everything is indexed and therefore you know what each value means.You can also - perform operations very easily on a dataframe. + perform operations very easily on a dataframe. ret_header: bool, optional whether to also return the headers that define each value in the returned array, will have no effect if `as_dataframe` is true. """ - #These are the properties that are written in SIESTA scf + # These are the properties that are written in SIESTA scf props = ["iscf", "Eharris", "E_KS", "FreeEng", "dDmax", "Ef", "dHmax"] if not iscf is None: if iscf == 0: - raise ValueError(f"{self.__class__.__name__}.read_scf requires iscf argument to *not* be 0!") + raise ValueError( + f"{self.__class__.__name__}.read_scf requires iscf argument to *not* be 0!" + ) if not imd is None: if imd == 0: - raise ValueError(f"{self.__class__.__name__}.read_scf requires imd argument to *not* be 0!") + raise ValueError( + f"{self.__class__.__name__}.read_scf requires imd argument to *not* be 0!" + ) + def reset_d(d, line): - if line.startswith('SCF cycle converged') or line.startswith('SCF_NOT_CONV'): - if len(d['data']) > 0: - d['_final_iscf'] = 1 - elif line.startswith('SCF cycle continued'): - d['_final_iscf'] = 0 + if line.startswith("SCF cycle converged") or line.startswith( + "SCF_NOT_CONV" + ): + if len(d["data"]) > 0: + d["_final_iscf"] = 1 + elif line.startswith("SCF cycle continued"): + d["_final_iscf"] = 0 def common_parse(line, d): nonlocal props - if line.startswith('ts-Vha:'): - d['ts-Vha'] = [float(line.split()[1])] - if 'ts-Vha' not in props: - d['order'].append("ts-Vha") + if line.startswith("ts-Vha:"): + d["ts-Vha"] = [float(line.split()[1])] + if "ts-Vha" not in props: + d["order"].append("ts-Vha") props.append("ts-Vha") elif line.startswith("spin moment: S"): # 4.1 and earlier - d['S'] = list(map(float, line.split("=")[1].split()[1:])) - if 'Sx' not in props: - d['order'].append("S") - props.extend(['Sx', 'Sy', 'Sz']) + d["S"] = list(map(float, line.split("=")[1].split()[1:])) + if "Sx" not in props: + d["order"].append("S") + props.extend(["Sx", "Sy", "Sz"]) elif line.startswith("spin moment: {S}"): # 4.2 and later - d['S'] = list(map(float, line.split("= {")[1].split()[:3])) - if 'Sx' not in props: - d['order'].append("S") - props.extend(['Sx', 'Sy', 'Sz']) - elif line.startswith('bulk-bias: |v'): + d["S"] = list(map(float, line.split("= {")[1].split()[:3])) + if "Sx" not in props: + d["order"].append("S") + props.extend(["Sx", "Sy", "Sz"]) + elif line.startswith("bulk-bias: |v"): # TODO old version should be removed once released - d['bb-v'] = list(map(float, line.split()[-3:])) - if 'BB-vx' not in props: - d['order'].append("bb-v") - props.extend(['BB-vx', 'BB-vy', 'BB-vz']) + d["bb-v"] = list(map(float, line.split()[-3:])) + if "BB-vx" not in props: + d["order"].append("bb-v") + props.extend(["BB-vx", "BB-vy", "BB-vz"]) elif line.startswith("bulk-bias: {v}"): idx = line.index("{v}") if line[idx + 3] == "_": @@ -768,7 +789,7 @@ def common_parse(line, d): d["order"].append(lbl) props.extend([f"{lbl}-vx", f"{lbl}-vy", f"{lbl}-vz"]) elif line.startswith("bulk-bias: dq"): - d['BB-q'] = list(map(float, line.split()[-2:])) + d["BB-q"] = list(map(float, line.split()[-2:])) if "BB-dq" not in props: d["order"].append("BB-q") props.extend(["BB-dq", "BB-q0"]) @@ -776,56 +797,84 @@ def common_parse(line, d): return False return True - if key.lower() == 'scf': + if key.lower() == "scf": + def parse_next(line, d): - line = line.strip().replace('*', '0') + line = line.strip().replace("*", "0") reset_d(d, line) if common_parse(line, d): pass - elif line.startswith('scf:'): - d['_found_iscf'] = True + elif line.startswith("scf:"): + d["_found_iscf"] = True if len(line) == 97: # this should be for Efup/dwn # but I think this will fail for as_dataframe (TODO) - data = [int(line[5:9]), float(line[9:25]), float(line[25:41]), - float(line[41:57]), float(line[57:67]), float(line[67:77]), - float(line[77:87]), float(line[87:97])] + data = [ + int(line[5:9]), + float(line[9:25]), + float(line[25:41]), + float(line[41:57]), + float(line[57:67]), + float(line[67:77]), + float(line[77:87]), + float(line[87:97]), + ] elif len(line) == 87: - data = [int(line[5:9]), float(line[9:25]), float(line[25:41]), - float(line[41:57]), float(line[57:67]), float(line[67:77]), - float(line[77:87])] + data = [ + int(line[5:9]), + float(line[9:25]), + float(line[25:41]), + float(line[41:57]), + float(line[57:67]), + float(line[67:77]), + float(line[77:87]), + ] else: # Populate DATA by splitting data = line.split() - data = [int(data[1])] + list(map(float, data[2:])) + data = [int(data[1])] + list(map(float, data[2:])) construct_data(d, data) - elif key.lower() == 'ts-scf': + elif key.lower() == "ts-scf": + def parse_next(line, d): - line = line.strip().replace('*', '0') + line = line.strip().replace("*", "0") reset_d(d, line) if common_parse(line, d): pass - elif line.startswith('ts-q:'): + elif line.startswith("ts-q:"): data = line.split()[1:] try: - d['ts-q'] = list(map(float, data)) + d["ts-q"] = list(map(float, data)) except Exception: # We are probably reading a device list # ensure that props are appended if data[-1] not in props: d["order"].append("ts-q") props.extend(data) - elif line.startswith('ts-scf:'): - d['_found_iscf'] = True + elif line.startswith("ts-scf:"): + d["_found_iscf"] = True if len(line) == 100: - data = [int(line[8:12]), float(line[12:28]), float(line[28:44]), - float(line[44:60]), float(line[60:70]), float(line[70:80]), - float(line[80:90]), float(line[90:100])] + data = [ + int(line[8:12]), + float(line[12:28]), + float(line[28:44]), + float(line[44:60]), + float(line[60:70]), + float(line[70:80]), + float(line[80:90]), + float(line[90:100]), + ] elif len(line) == 90: - data = [int(line[8:12]), float(line[12:28]), float(line[28:44]), - float(line[44:60]), float(line[60:70]), float(line[70:80]), - float(line[80:90])] + data = [ + int(line[8:12]), + float(line[12:28]), + float(line[28:44]), + float(line[44:60]), + float(line[60:70]), + float(line[70:80]), + float(line[80:90]), + ] else: # Populate DATA by splitting data = line.split() @@ -834,11 +883,12 @@ def parse_next(line, d): # A temporary dictionary to hold information while reading the output file d = { - '_found_iscf': False, - '_final_iscf': 0, - 'data': [], - 'order': [], + "_found_iscf": False, + "_final_iscf": 0, + "data": [], + "order": [], } + def construct_data(d, data): for key in d["order"]: data.extend(d[key]) @@ -848,9 +898,9 @@ def construct_data(d, data): scf = [] for line in self: parse_next(line, d) - if d['_found_iscf']: - d['_found_iscf'] = False - data = d['data'] + if d["_found_iscf"]: + d["_found_iscf"] = False + data = d["data"] if len(data) == 0: continue @@ -861,11 +911,11 @@ def construct_data(d, data): # case the requested iscf is too big scf = data - if d['_final_iscf'] == 1: - d['_final_iscf'] = 2 - elif d['_final_iscf'] == 2: - d['_final_iscf'] = 0 - data = d['data'] + if d["_final_iscf"] == 1: + d["_final_iscf"] = 2 + elif d["_final_iscf"] == 2: + d["_final_iscf"] = 0 + data = d["data"] if len(data) == 0: # this traps the case where we read ts-scf # but find the final scf iteration. @@ -880,7 +930,7 @@ def construct_data(d, data): continue # First figure out which iscf we should store - if iscf is None: # or iscf > 0 + if iscf is None: # or iscf > 0 # scf is correct pass elif iscf < 0: @@ -905,9 +955,8 @@ def MDstep_dataframe(scf): scf = np.atleast_2d(scf) return pd.DataFrame( scf[..., 1:], - index=pd.Index(scf[..., 0].ravel().astype(np.int32), - name="iscf"), - columns=props[1:] + index=pd.Index(scf[..., 0].ravel().astype(np.int32), name="iscf"), + columns=props[1:], ) # Now we know how many MD steps there are @@ -921,12 +970,14 @@ def MDstep_dataframe(scf): if as_dataframe: if len(md) == 0: # return an empty dataframe (with imd as index) - return pd.DataFrame(index=pd.Index([], name="imd"), - columns=props) + return pd.DataFrame(index=pd.Index([], name="imd"), columns=props) # Regardless of what the user requests we will always have imd == index # and iscf a column, a user may easily change this. - df = pd.concat(map(MDstep_dataframe, md), - keys=_a.arangei(1, len(md) + 1), names=["imd"]) + df = pd.concat( + map(MDstep_dataframe, md), + keys=_a.arangei(1, len(md) + 1), + names=["imd"], + ) if iscf is not None: df.reset_index("iscf", inplace=True) return df @@ -944,15 +995,16 @@ def MDstep_dataframe(scf): if len(md) == 0: # no data collected if as_dataframe: - return pd.DataFrame(index=pd.Index([], name="iscf"), - columns=props[1:]) + return pd.DataFrame(index=pd.Index([], name="iscf"), columns=props[1:]) md = np.array(md[imd]) if ret_header: return md, props return md if imd > len(md): - raise ValueError(f"{self.__class__.__name__}.read_scf could not find requested MD step ({imd}).") + raise ValueError( + f"{self.__class__.__name__}.read_scf could not find requested MD step ({imd})." + ) # If a certain imd was requested, get it # Remember that if imd is positive, we stopped reading at the moment we reached it @@ -964,7 +1016,9 @@ def MDstep_dataframe(scf): return scf @sile_fh_open(True) - def read_charge(self, name, iscf=Opt.ANY, imd=Opt.ANY, key_scf="scf", as_dataframe=False): + def read_charge( + self, name, iscf=Opt.ANY, imd=Opt.ANY, key_scf="scf", as_dataframe=False + ): r"""Read charges calculated in SCF loop or MD loop (or both) Siesta enables many different modes of writing out charges. @@ -1032,19 +1086,23 @@ def read_charge(self, name, iscf=Opt.ANY, imd=Opt.ANY, key_scf="scf", as_datafra namel = name.lower() if as_dataframe: import pandas as pd + def _empty_charge(): # build a fake dataframe with no indices - return pd.DataFrame(index=pd.Index([], name="atom", dtype=np.int32), - dtype=np.float32) + return pd.DataFrame( + index=pd.Index([], name="atom", dtype=np.int32), dtype=np.float32 + ) + else: pd = None + def _empty_charge(): # return for single value with nan values return _a.arrayf([[None]]) # define helper function for reading voronoi+hirshfeld charges def _voronoi_hirshfeld_charges(): - """ Read output from Voronoi/Hirshfeld charges """ + """Read output from Voronoi/Hirshfeld charges""" nonlocal pd # Expecting something like this (NC/SOC) @@ -1057,11 +1115,13 @@ def _voronoi_hirshfeld_charges(): # 1 -0.02936 4.02936 0.00000 C # first line is the header - header = (self.readline() - .replace("dQatom", "dq") # dQatom in master - .replace(" Qatom", " dq") # Qatom in 4.1 - .replace("Atom pop", "e") # not found in 4.1 - .split())[2:-1] + header = ( + self.readline() + .replace("dQatom", "dq") # dQatom in master + .replace(" Qatom", " dq") # Qatom in 4.1 + .replace("Atom pop", "e") # not found in 4.1 + .split() + )[2:-1] # Define the function that parses the charges def _parse_charge(line): @@ -1069,12 +1129,12 @@ def _parse_charge(line): # assert that this is a proper line # this should catch cases where the following line of charge output # is still parseable - #atom_idx = int(atom_idx) + # atom_idx = int(atom_idx) return list(map(float, vals)) # We have found the header, prepare a list to read the charges atom_charges = [] - line = ' ' + line = " " while line != "": try: line = self.readline() @@ -1094,28 +1154,38 @@ def _parse_charge(line): assert ncols == len(header) # the precision is limited, so no need for double precision - return pd.DataFrame(atom_charges, columns=header, dtype=np.float32, - index=pd.RangeIndex(stop=len(atom_charges), name="atom")) + return pd.DataFrame( + atom_charges, + columns=header, + dtype=np.float32, + index=pd.RangeIndex(stop=len(atom_charges), name="atom"), + ) # define helper function for reading voronoi+hirshfeld charges def _mulliken_charges(): - """ Read output from Mulliken charges """ + """Read output from Mulliken charges""" raise NotImplementedError("Mulliken charges are not implemented currently") # Check that a known charge has been requested if namel == "voronoi": _r_charge = _voronoi_hirshfeld_charges - charge_keys = ["Voronoi Atomic Populations", - "Voronoi Net Atomic Populations"] + charge_keys = [ + "Voronoi Atomic Populations", + "Voronoi Net Atomic Populations", + ] elif namel == "hirshfeld": _r_charge = _voronoi_hirshfeld_charges - charge_keys = ["Hirshfeld Atomic Populations", - "Hirshfeld Net Atomic Populations"] + charge_keys = [ + "Hirshfeld Atomic Populations", + "Hirshfeld Net Atomic Populations", + ] elif namel == "mulliken": _r_charge = _mulliken_charges charge_keys = ["mulliken: Atomic and Orbital Populations"] else: - raise ValueError(f"{self.__class__.__name__}.read_charge name argument should be one of [voronoi, hirshfeld, mulliken], got {name}?") + raise ValueError( + f"{self.__class__.__name__}.read_charge name argument should be one of [voronoi, hirshfeld, mulliken], got {name}?" + ) # Ensure the key_scf matches exactly (prepend a space) key_scf = f" {key_scf.strip()}:" @@ -1125,10 +1195,11 @@ def _mulliken_charges(): # to see if we finished a MD read, we check for these keys search_keys = [ # two keys can signal ending SCF - "SCF Convergence", "SCF_NOT_CONV", + "SCF Convergence", + "SCF_NOT_CONV", "siesta: Final energy", key_scf, - *charge_keys + *charge_keys, ] # adjust the below while loop to take into account any additional # segments of search_keys @@ -1136,8 +1207,7 @@ def _mulliken_charges(): IDX_FINAL = [2] IDX_SCF = [3] # the rest are charge keys - IDX_CHARGE = list(range(len(search_keys) - len(charge_keys), - len(search_keys))) + IDX_CHARGE = list(range(len(search_keys) - len(charge_keys), len(search_keys))) # state to figure out where we are state = PropertyDict() @@ -1224,7 +1294,9 @@ def _mulliken_charges(): current_state = state.CHARGE # step to next entry - ret = self.step_to(search_keys, case=True, ret_index=True, allow_reread=False) + ret = self.step_to( + search_keys, case=True, ret_index=True, allow_reread=False + ) if not any((FOUND_SCF, FOUND_MD, FOUND_FINAL)): raise SileError(f"{self!s} does not contain any charges ({name})") @@ -1247,21 +1319,30 @@ def _mulliken_charges(): # convert data to proper data structures # regardless of user requests. This is an overhead... But probably not that big of a problem. if FOUND_SCF: - md_scf_charge = pd.concat([pd.concat(iscf, - keys=pd.RangeIndex(1, len(iscf)+1, name="iscf")) - for iscf in md_scf_charge], - keys=pd.RangeIndex(1, len(md_scf_charge)+1, name="imd")) + md_scf_charge = pd.concat( + [ + pd.concat( + iscf, keys=pd.RangeIndex(1, len(iscf) + 1, name="iscf") + ) + for iscf in md_scf_charge + ], + keys=pd.RangeIndex(1, len(md_scf_charge) + 1, name="imd"), + ) if FOUND_MD: - md_charge = pd.concat(md_charge, keys=pd.RangeIndex(1, len(md_charge)+1, name="imd")) + md_charge = pd.concat( + md_charge, keys=pd.RangeIndex(1, len(md_charge) + 1, name="imd") + ) else: if FOUND_SCF: nan_array = _a.emptyf(md_scf_charge[0][0].shape) nan_array.fill(np.nan) + def get_md_scf_charge(scf_charge, iscf): try: return scf_charge[iscf] except Exception: return nan_array + if FOUND_MD: md_charge = np.stack(md_charge) @@ -1269,7 +1350,7 @@ def get_md_scf_charge(scf_charge, iscf): # So first figure out what is there, and handle this based # on arguments def _p(flag, found): - """ Helper routine to do the following: + """Helper routine to do the following: Returns ------- @@ -1289,7 +1370,7 @@ def _p(flag, found): # flag is only NONE, then pass none if not (Opt.NONE ^ flag): flag = None - else: # not found + else: # not found # we convert flag to none # if ANY or NONE in flag if (Opt.NONE | Opt.ANY) & flag: @@ -1303,8 +1384,7 @@ def _p(flag, found): if not (FOUND_SCF or FOUND_MD): # none of these are found # we request that user does not request any input - if (opt_iscf or (not iscf is None)) or \ - (opt_imd or (not imd is None)): + if (opt_iscf or (not iscf is None)) or (opt_imd or (not imd is None)): raise SileError(f"{self!s} does not contain MD/SCF charges") elif not FOUND_SCF: @@ -1339,11 +1419,15 @@ def _p(flag, found): # this should be handled, i.e. the scf should be taken out if as_dataframe: return md_scf_charge.groupby(level=[0, 2]).nth(iscf) - return np.stack(tuple(get_md_scf_charge(x, iscf) for x in md_scf_charge)) + return np.stack( + tuple(get_md_scf_charge(x, iscf) for x in md_scf_charge) + ) elif FOUND_MD and iscf is None: return md_charge - raise SileError(f"{str(self)} unknown argument for 'imd' and 'iscf', could not find SCF charges") + raise SileError( + f"{str(self)} unknown argument for 'imd' and 'iscf', could not find SCF charges" + ) elif opt_iscf: # flag requested imd @@ -1383,7 +1467,9 @@ def _p(flag, found): return md_scf_charge[imd][iscf] -outSileSiesta = deprecation("outSileSiesta has been deprecated in favor of stdoutSileSiesta.", "0.15")(stdoutSileSiesta) +outSileSiesta = deprecation( + "outSileSiesta has been deprecated in favor of stdoutSileSiesta.", "0.15" +)(stdoutSileSiesta) add_sile("siesta.out", stdoutSileSiesta, case=False, gzip=True) add_sile("out", stdoutSileSiesta, case=False, gzip=True) diff --git a/src/sisl/io/siesta/struct.py b/src/sisl/io/siesta/struct.py index 1498852058..fdca6f0888 100644 --- a/src/sisl/io/siesta/struct.py +++ b/src/sisl/io/siesta/struct.py @@ -10,19 +10,19 @@ from ..sile import add_sile, sile_fh_open, sile_raise_write from .sile import SileSiesta -__all__ = ['structSileSiesta'] +__all__ = ["structSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') +Bohr2Ang = unit_convert("Bohr", "Ang") @set_module("sisl.io.siesta") class structSileSiesta(SileSiesta): - """ Geometry file """ + """Geometry file""" @sile_fh_open() - def write_geometry(self, geometry, fmt='.9f'): - """ Writes the geometry to the contained file + def write_geometry(self, geometry, fmt=".9f"): + """Writes the geometry to the contained file Parameters ---------- @@ -35,15 +35,15 @@ def write_geometry(self, geometry, fmt='.9f'): sile_raise_write(self) # Create format string for the cell-parameters - fmt_str = ' ' + ('{:' + fmt + '} ') * 3 + '\n' + fmt_str = " " + ("{:" + fmt + "} ") * 3 + "\n" for i in range(3): self._write(fmt_str.format(*geometry.cell[i])) - self._write(f'{geometry.na:12d}\n') + self._write(f"{geometry.na:12d}\n") # Create format string for the atomic coordinates fxyz = geometry.fxyz - fmt_str = '{:3d}{:6d} ' - fmt_str += ('{:' + fmt + '} ') * 3 + '\n' + fmt_str = "{:3d}{:6d} " + fmt_str += ("{:" + fmt + "} ") * 3 + "\n" for ia, a, ips in geometry.iter_species(): if isinstance(a, AtomGhost): self._write(fmt_str.format(ips + 1, -a.Z, *fxyz[ia])) @@ -52,7 +52,7 @@ def write_geometry(self, geometry, fmt='.9f'): @sile_fh_open() def read_lattice(self): - """ Returns `Lattice` object from the STRUCT file """ + """Returns `Lattice` object from the STRUCT file""" cell = np.empty([3, 3], np.float64) for i in range(3): @@ -62,7 +62,7 @@ def read_lattice(self): @sile_fh_open() def read_geometry(self, species_Z=False): - """ Returns a `Geometry` object from the STRUCT file + """Returns a `Geometry` object from the STRUCT file Parameters ---------- @@ -108,12 +108,12 @@ def read_geometry(self, species_Z=False): return Geometry(xyz, atms2.reduce(), lattice=lattice) def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('STRUCT_IN', structSileSiesta, gzip=True) -add_sile('STRUCT_NEXT_ITER', structSileSiesta, gzip=True) -add_sile('STRUCT_OUT', structSileSiesta, gzip=True) +add_sile("STRUCT_IN", structSileSiesta, gzip=True) +add_sile("STRUCT_NEXT_ITER", structSileSiesta, gzip=True) +add_sile("STRUCT_OUT", structSileSiesta, gzip=True) diff --git a/src/sisl/io/siesta/tests/test_ani.py b/src/sisl/io/siesta/tests/test_ani.py index 16590a95d3..98ef98550c 100644 --- a/src/sisl/io/siesta/tests/test_ani.py +++ b/src/sisl/io/siesta/tests/test_ani.py @@ -10,11 +10,13 @@ from sisl.io.siesta import aniSileSiesta pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") + def test_ani(sisl_tmp): - f = sisl_tmp('sisl.ANI', _dir) - open(f, 'w').write("""1 + f = sisl_tmp("sisl.ANI", _dir) + open(f, "w").write( + """1 C 0.00000000 0.00000000 0.00000000 2 @@ -32,7 +34,8 @@ def test_ani(sisl_tmp): C 1.000000 0.00000000 0.00000000 C 2.00000 0.00000000 0.00000000 C 3.00000 0.00000000 0.00000000 -""") +""" + ) a = aniSileSiesta(f) g = a.read_geometry[:]() assert len(g) == 4 diff --git a/src/sisl/io/siesta/tests/test_bands.py b/src/sisl/io/siesta/tests/test_bands.py index 65c0a6d6d4..052b5dcdca 100644 --- a/src/sisl/io/siesta/tests/test_bands.py +++ b/src/sisl/io/siesta/tests/test_bands.py @@ -10,33 +10,33 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_fe(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'fe.bands')) + si = sisl.get_sile(sisl_files(_dir, "fe.bands")) labels, k, eig = si.read_data() - assert k.shape == (131, ) + assert k.shape == (131,) assert eig.shape == (131, 2, 15) assert len(labels[0]) == 5 def test_fe_ArgumentParser(sisl_files, sisl_tmp): pytest.importorskip("matplotlib", reason="matplotlib not available") - png = sisl_tmp('fe.bands.png', _dir) - si = sisl.get_sile(sisl_files(_dir, 'fe.bands')) + png = sisl_tmp("fe.bands.png", _dir) + si = sisl.get_sile(sisl_files(_dir, "fe.bands")) p, ns = si.ArgumentParser() p.parse_args([], namespace=ns) - p.parse_args(['--energy', ' -2:2'], namespace=ns) - p.parse_args(['--energy', ' -2:2', '--plot', png], namespace=ns) + p.parse_args(["--energy", " -2:2"], namespace=ns) + p.parse_args(["--energy", " -2:2", "--plot", png], namespace=ns) def test_fe_xarray(sisl_files, sisl_tmp): pytest.importorskip("xarray", reason="xarray not available") - si = sisl.get_sile(sisl_files(_dir, 'fe.bands')) + si = sisl.get_sile(sisl_files(_dir, "fe.bands")) bands = si.read_data(as_dataarray=True) - assert len(bands['k']) == 131 - assert len(bands['spin']) == 2 - assert len(bands['band']) == 15 + assert len(bands["k"]) == 131 + assert len(bands["spin"]) == 2 + assert len(bands["band"]) == 15 assert len(bands.ticks) == len(bands.ticklabels) == 5 diff --git a/src/sisl/io/siesta/tests/test_basis.py b/src/sisl/io/siesta/tests/test_basis.py index 3620d14abc..383a44a24a 100644 --- a/src/sisl/io/siesta/tests/test_basis.py +++ b/src/sisl/io/siesta/tests/test_basis.py @@ -9,12 +9,12 @@ from sisl.io.siesta.basis import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_si_ion_nc(sisl_files): pytest.importorskip("netCDF4") - f = sisl_files(_dir, 'Si.ion.nc') + f = sisl_files(_dir, "Si.ion.nc") with ionncSileSiesta(f) as sile: atom = sile.read_basis() @@ -23,7 +23,7 @@ def test_si_ion_nc(sisl_files): def test_si_ion_xml(sisl_files): - f = sisl_files(_dir, 'Si.ion.xml') + f = sisl_files(_dir, "Si.ion.xml") with ionxmlSileSiesta(f) as sile: atom = sile.read_basis() @@ -32,7 +32,7 @@ def test_si_ion_xml(sisl_files): def test_si_ion_xml_handle(sisl_files): - f = open(sisl_files(_dir, 'Si.ion.xml'), 'r') + f = open(sisl_files(_dir, "Si.ion.xml"), "r") with ionxmlSileSiesta(f) as sile: assert "Buffer" in sile.__class__.__name__ @@ -43,7 +43,7 @@ def test_si_ion_xml_handle(sisl_files): def test_si_ion_xml_stringio(sisl_files): - f = StringIO(open(sisl_files(_dir, 'Si.ion.xml'), 'r').read()) + f = StringIO(open(sisl_files(_dir, "Si.ion.xml"), "r").read()) with ionxmlSileSiesta(f) as sile: assert "Buffer" in sile.__class__.__name__ @@ -55,11 +55,11 @@ def test_si_ion_xml_stringio(sisl_files): def test_si_ion_compare(sisl_files): pytest.importorskip("netCDF4") - f = sisl_files(_dir, 'Si.ion.nc') + f = sisl_files(_dir, "Si.ion.nc") with ionncSileSiesta(f) as sile: nc = sile.read_basis() - f = sisl_files(_dir, 'Si.ion.xml') + f = sisl_files(_dir, "Si.ion.xml") with ionxmlSileSiesta(f) as sile: xml = sile.read_basis() diff --git a/src/sisl/io/siesta/tests/test_dm.py b/src/sisl/io/siesta/tests/test_dm.py index b687e111d0..863ff1552f 100644 --- a/src/sisl/io/siesta/tests/test_dm.py +++ b/src/sisl/io/siesta/tests/test_dm.py @@ -11,27 +11,27 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_dm_si_pdos_kgrid(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf"), base=sisl_files(_dir)) - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.DM')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.DM")) DM1 = si.read_density_matrix(geometry=fdf.read_geometry()) - DM2 = fdf.read_density_matrix(order=['DM']) + DM2 = fdf.read_density_matrix(order=["DM"]) assert DM1._csr.spsame(DM2._csr) assert np.allclose(DM1._csr._D[:, :-1], DM2._csr._D[:, :-1]) def test_dm_si_pdos_kgrid_rw(sisl_files, sisl_tmp): - fdf = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf"), base=sisl_files(_dir)) geom = fdf.read_geometry() - f1 = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.DM')) - f2 = sisl.get_sile(sisl_tmp('test.DM', _dir)) + f1 = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.DM")) + f2 = sisl.get_sile(sisl_tmp("test.DM", _dir)) DM1 = f1.read_density_matrix(geometry=geom) f2.write_density_matrix(DM1, sort=False) @@ -49,11 +49,11 @@ def test_dm_si_pdos_kgrid_rw(sisl_files, sisl_tmp): def test_dm_si_pdos_kgrid_mulliken(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf'), base=sisl_files(_dir)) - DM = fdf.read_density_matrix(order=['DM']) + fdf = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf"), base=sisl_files(_dir)) + DM = fdf.read_density_matrix(order=["DM"]) - Mo = DM.mulliken('orbital') - Ma = DM.mulliken('atom') + Mo = DM.mulliken("orbital") + Ma = DM.mulliken("atom") o2a = DM.geometry.o2a(np.arange(DM.no)) @@ -63,12 +63,12 @@ def test_dm_si_pdos_kgrid_mulliken(sisl_files): def test_dm_soc_pt2_xx_mulliken(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'SOC_Pt2_xx.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "SOC_Pt2_xx.fdf"), base=sisl_files(_dir)) # Force reading a geometry with correct atomic and orbital configuration - DM = fdf.read_density_matrix(order=['DM']) + DM = fdf.read_density_matrix(order=["DM"]) - Mo = DM.mulliken('orbital') - Ma = DM.mulliken('atom') + Mo = DM.mulliken("orbital") + Ma = DM.mulliken("atom") o2a = DM.geometry.o2a(np.arange(DM.no)) @@ -78,8 +78,8 @@ def test_dm_soc_pt2_xx_mulliken(sisl_files): def test_dm_soc_pt2_xx_rw(sisl_files, sisl_tmp): - f1 = sisl.get_sile(sisl_files(_dir, 'SOC_Pt2_xx.DM')) - f2 = sisl.get_sile(sisl_tmp('test.DM', _dir)) + f1 = sisl.get_sile(sisl_files(_dir, "SOC_Pt2_xx.DM")) + f2 = sisl.get_sile(sisl_tmp("test.DM", _dir)) DM1 = f1.read_density_matrix() f2.write_density_matrix(DM1) @@ -91,17 +91,19 @@ def test_dm_soc_pt2_xx_rw(sisl_files, sisl_tmp): assert np.allclose(DM1._csr._D[:, :-1], DM2._csr._D[:, :-1]) -@pytest.mark.xfail(reason="Currently reading a geometry from TSHS does not retain l, m, zeta quantum numbers") +@pytest.mark.xfail( + reason="Currently reading a geometry from TSHS does not retain l, m, zeta quantum numbers" +) def test_dm_soc_pt2_xx_orbital_momentum(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'SOC_Pt2_xx.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "SOC_Pt2_xx.fdf"), base=sisl_files(_dir)) # Force reading a geometry with correct atomic and orbital configuration - DM = fdf.read_density_matrix(order=['DM']) + DM = fdf.read_density_matrix(order=["DM"]) o2a = DM.geometry.o2a(np.arange(DM.no)) # Calculate angular momentum - Lo = DM.orbital_momentum('orbital') - La = DM.orbital_momentum('atom') + Lo = DM.orbital_momentum("orbital") + La = DM.orbital_momentum("atom") la = np.zeros_like(La) np.add.at(la, o2a, Lo.T) diff --git a/src/sisl/io/siesta/tests/test_eig.py b/src/sisl/io/siesta/tests/test_eig.py index 21a56da9fa..cad10a6ff8 100644 --- a/src/sisl/io/siesta/tests/test_eig.py +++ b/src/sisl/io/siesta/tests/test_eig.py @@ -15,8 +15,9 @@ def _convert45(unit): - """ Convert from legacy units to CODATA2018 """ + """Convert from legacy units to CODATA2018""" from sisl.unit.siesta import units, units_legacy + return units(unit) / units_legacy(unit) @@ -27,6 +28,7 @@ def test_si_pdos_kgrid_eig(sisl_files): # nspin, nk, nb assert np.all(eig.shape == (1, 32, 26)) + def test_si_pdos_kgrid_eig_ArgumentParser(sisl_files, sisl_tmp): pytest.importorskip("matplotlib", reason="matplotlib not available") png = sisl_tmp("si_pdos_kgrid.EIG.png", _dir) diff --git a/src/sisl/io/siesta/tests/test_fa.py b/src/sisl/io/siesta/tests/test_fa.py index a834b9466a..7c0e86cf9e 100644 --- a/src/sisl/io/siesta/tests/test_fa.py +++ b/src/sisl/io/siesta/tests/test_fa.py @@ -9,11 +9,11 @@ from sisl.io.siesta.fa import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_si_pdos_kgrid_fa(sisl_files): - f = sisl_files(_dir, 'si_pdos_kgrid.FA') + f = sisl_files(_dir, "si_pdos_kgrid.FA") fa = faSileSiesta(f).read_data() assert len(fa) == 2 @@ -22,10 +22,10 @@ def test_si_pdos_kgrid_fa(sisl_files): def test_read_write_fa(sisl_tmp): - f = sisl_tmp('test.FA', _dir) + f = sisl_tmp("test.FA", _dir) fa = np.random.rand(10, 3) - faSileSiesta(f, 'w').write_force(fa) + faSileSiesta(f, "w").write_force(fa) fa2 = faSileSiesta(f).read_force() assert len(fa) == len(fa2) diff --git a/src/sisl/io/siesta/tests/test_fc.py b/src/sisl/io/siesta/tests/test_fc.py index 47061e3bd8..3ef3dadd72 100644 --- a/src/sisl/io/siesta/tests/test_fc.py +++ b/src/sisl/io/siesta/tests/test_fc.py @@ -10,20 +10,20 @@ from sisl.unit.siesta import unit_convert pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_read_fc(sisl_tmp): - f = sisl_tmp('test.FC', _dir) + f = sisl_tmp("test.FC", _dir) fc = np.random.rand(20, 6, 2, 3) sign = 1 - with open(f, 'w') as fh: - fh.write('sotaeuha 2 0.5\n') + with open(f, "w") as fh: + fh.write("sotaeuha 2 0.5\n") for n in fc: for dx in n: for a in dx * sign: - fh.write('{} {} {}\n'.format(*a)) + fh.write("{} {} {}\n".format(*a)) sign *= -1 fc.shape = (20, 3, 2, 2, 3) @@ -35,18 +35,18 @@ def test_read_fc(sisl_tmp): @pytest.mark.filterwarnings("ignore", message="*assumes displacement=") def test_read_fc_old(sisl_tmp): - f = sisl_tmp('test2.FC', _dir) + f = sisl_tmp("test2.FC", _dir) fc = np.random.rand(20, 6, 2, 3) - with open(f, 'w') as fh: - fh.write('sotaeuha\n') + with open(f, "w") as fh: + fh.write("sotaeuha\n") for n in fc: for dx in n: for a in dx: - fh.write('{} {} {}\n'.format(*a)) + fh.write("{} {} {}\n".format(*a)) fc.shape = (20, 3, 2, 2, 3) - fc2 = fcSileSiesta(f).read_force() / (0.04 * unit_convert('Bohr', 'Ang')) + fc2 = fcSileSiesta(f).read_force() / (0.04 * unit_convert("Bohr", "Ang")) assert fc.shape != fc2.shape fc2 *= np.tile([1, -1], 3).reshape(1, 3, 2, 1, 1) fc2.shape = (-1, 3, 2, 2, 3) @@ -65,7 +65,7 @@ def test_read_fc_old(sisl_tmp): assert np.allclose(fc, fc2) # Specify number of atoms and correction to check they are equivalent - fc2 = fcSileSiesta(f).read_force(-1., na=2) + fc2 = fcSileSiesta(f).read_force(-1.0, na=2) assert fc.shape == fc2.shape fc2 *= np.tile([-1, 1], 3).reshape(1, 3, 2, 1, 1) assert np.allclose(fc, fc2) diff --git a/src/sisl/io/siesta/tests/test_fdf.py b/src/sisl/io/siesta/tests/test_fdf.py index ed5750eb90..4d7453c70e 100644 --- a/src/sisl/io/siesta/tests/test_fdf.py +++ b/src/sisl/io/siesta/tests/test_fdf.py @@ -13,34 +13,36 @@ from sisl.messages import SislWarning from sisl.unit.siesta import unit_convert -pytestmark = [pytest.mark.io, pytest.mark.siesta, pytest.mark.fdf, - pytest.mark.filterwarnings("ignore", message="*number of supercells") +pytestmark = [ + pytest.mark.io, + pytest.mark.siesta, + pytest.mark.fdf, + pytest.mark.filterwarnings("ignore", message="*number of supercells"), ] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_fdf1(sisl_tmp, sisl_system): - f = sisl_tmp('gr.fdf', _dir) - sisl_system.g.write(fdfSileSiesta(f, 'w')) + f = sisl_tmp("gr.fdf", _dir) + sisl_system.g.write(fdfSileSiesta(f, "w")) fdf = fdfSileSiesta(f) str(fdf) with fdf: - fdf.readline() # Be sure that we can read it in a loop - assert fdf.get('LatticeConstant') > 0. - assert fdf.get('LatticeConstant') > 0. - assert fdf.get('LatticeConstant') > 0. + assert fdf.get("LatticeConstant") > 0.0 + assert fdf.get("LatticeConstant") > 0.0 + assert fdf.get("LatticeConstant") > 0.0 fdf.read_lattice() fdf.read_geometry() def test_fdf2(sisl_tmp, sisl_system): - f = sisl_tmp('gr.fdf', _dir) - sisl_system.g.write(fdfSileSiesta(f, 'w')) + f = sisl_tmp("gr.fdf", _dir) + sisl_system.g.write(fdfSileSiesta(f, "w")) g = fdfSileSiesta(f).read_geometry() # Assert they are the same @@ -52,11 +54,11 @@ def test_fdf2(sisl_tmp, sisl_system): def test_fdf_units(sisl_tmp, sisl_system): - f = sisl_tmp('gr.fdf', _dir) - fdf = fdfSileSiesta(f, 'w') + f = sisl_tmp("gr.fdf", _dir) + fdf = fdfSileSiesta(f, "w") g = sisl_system.g - for unit in ['bohr', 'ang', 'fractional', 'frac']: + for unit in ["bohr", "ang", "fractional", "frac"]: fdf.write_geometry(g, unit=unit) g2 = fdfSileSiesta(f).read_geometry() assert np.allclose(g.cell, g2.cell) @@ -67,243 +69,239 @@ def test_fdf_units(sisl_tmp, sisl_system): def test_lattice(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) + f = sisl_tmp("file.fdf", _dir) lines = [ - 'Latticeconstant 1. Ang', - '%block Latticevectors', - ' 1. 1. 1.', - ' 0. 0. 1.', - ' 1. 0. 1.', - '%endblock', + "Latticeconstant 1. Ang", + "%block Latticevectors", + " 1. 1. 1.", + " 0. 0. 1.", + " 1. 0. 1.", + "%endblock", ] - with open(f, 'w') as fh: - fh.write('\n'.join(lines)) + with open(f, "w") as fh: + fh.write("\n".join(lines)) - cell = np.array([[1.]*3, [0, 0, 1], [1, 0, 1]]) + cell = np.array([[1.0] * 3, [0, 0, 1], [1, 0, 1]]) lattice = fdfSileSiesta(f).read_lattice() assert np.allclose(lattice.cell, cell) lines = [ - 'Latticeconstant 1. Bohr', - '%block Latticevectors', - ' 1. 1. 1.', - ' 0. 0. 1.', - ' 1. 0. 1.', - '%endblock', + "Latticeconstant 1. Bohr", + "%block Latticevectors", + " 1. 1. 1.", + " 0. 0. 1.", + " 1. 0. 1.", + "%endblock", ] - with open(f, 'w') as fh: - fh.write('\n'.join(lines)) + with open(f, "w") as fh: + fh.write("\n".join(lines)) lattice = fdfSileSiesta(f).read_lattice() - assert np.allclose(lattice.cell, cell * unit_convert('Bohr', 'Ang')) + assert np.allclose(lattice.cell, cell * unit_convert("Bohr", "Ang")) - cell = np.diag([2.] * 3) + cell = np.diag([2.0] * 3) lines = [ - 'Latticeconstant 2. Ang', - '%block Latticeparameters', - ' 1. 1. 1. 90. 90. 90.', - '%endblock', + "Latticeconstant 2. Ang", + "%block Latticeparameters", + " 1. 1. 1. 90. 90. 90.", + "%endblock", ] - with open(f, 'w') as fh: - fh.write('\n'.join(lines)) + with open(f, "w") as fh: + fh.write("\n".join(lines)) lattice = fdfSileSiesta(f).read_lattice() assert np.allclose(lattice.cell, cell) def test_lattice_fail(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) + f = sisl_tmp("file.fdf", _dir) lines = [ - '%block Latticevectors', - ' 1. 1. 1.', - ' 0. 0. 1.', - ' 1. 0. 1.', - '%endblock', + "%block Latticevectors", + " 1. 1. 1.", + " 0. 0. 1.", + " 1. 0. 1.", + "%endblock", ] - with open(f, 'w') as fh: - fh.write('\n'.join(lines)) + with open(f, "w") as fh: + fh.write("\n".join(lines)) with pytest.raises(SileError): fdfSileSiesta(f).read_lattice() def test_geometry(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) + f = sisl_tmp("file.fdf", _dir) sc_lines = [ - 'Latticeconstant 1. Ang', - '%block latticeparameters', - ' 1. 1. 1. 90. 90. 90.', - '%endblock', + "Latticeconstant 1. Ang", + "%block latticeparameters", + " 1. 1. 1. 90. 90. 90.", + "%endblock", ] lines = [ - 'NumberOfAtoms 2', - '%block chemicalSpeciesLabel', - ' 1 6 C', - ' 2 12 H', - '%endblock', - 'AtomicCoordinatesFormat Ang', - '%block atomiccoordinatesandatomicspecies', - ' 1. 1. 1. 1', - ' 0. 0. 1. 1', - ' 1. 0. 1. 2', - '%endblock', + "NumberOfAtoms 2", + "%block chemicalSpeciesLabel", + " 1 6 C", + " 2 12 H", + "%endblock", + "AtomicCoordinatesFormat Ang", + "%block atomiccoordinatesandatomicspecies", + " 1. 1. 1. 1", + " 0. 0. 1. 1", + " 1. 0. 1. 2", + "%endblock", ] - with open(f, 'w') as fh: - fh.write('\n'.join(sc_lines) + '\n') - fh.write('\n'.join(lines)) + with open(f, "w") as fh: + fh.write("\n".join(sc_lines) + "\n") + fh.write("\n".join(lines)) fdf = fdfSileSiesta(f, base=sisl_tmp.getbase()) g = fdf.read_geometry() assert g.na == 2 - assert np.allclose(g.xyz, [[1.] * 3, - [0, 0, 1]]) + assert np.allclose(g.xyz, [[1.0] * 3, [0, 0, 1]]) assert g.atoms[0].Z == 6 assert g.atoms[1].Z == 6 # default read # of atoms from list - with open(f, 'w') as fh: - fh.write('\n'.join(sc_lines) + '\n') - fh.write('\n'.join(lines[1:])) + with open(f, "w") as fh: + fh.write("\n".join(sc_lines) + "\n") + fh.write("\n".join(lines[1:])) fdf = fdfSileSiesta(f, base=sisl_tmp.getbase()) g = fdf.read_geometry() assert g.na == 3 - assert np.allclose(g.xyz, [[1.] * 3, - [0, 0, 1], - [1, 0, 1]]) + assert np.allclose(g.xyz, [[1.0] * 3, [0, 0, 1], [1, 0, 1]]) assert g.atoms[0].Z == 6 assert g.atoms[1].Z == 6 assert g.atoms[2].Z == 12 def test_re_read(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) - with open(f, 'w') as fh: - fh.write('Flag1 date\n') - fh.write('Flag1 not-date\n') - fh.write('Flag1 not-date-2\n') - fh.write('Flag3 true\n') + f = sisl_tmp("file.fdf", _dir) + with open(f, "w") as fh: + fh.write("Flag1 date\n") + fh.write("Flag1 not-date\n") + fh.write("Flag1 not-date-2\n") + fh.write("Flag3 true\n") fdf = fdfSileSiesta(f) for i in range(10): - assert fdf.get('Flag1') == 'date' - assert fdf.get('Flag3') + assert fdf.get("Flag1") == "date" + assert fdf.get("Flag3") def test_get_set(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) - with open(f, 'w') as fh: - fh.write('Flag1 date\n') + f = sisl_tmp("file.fdf", _dir) + with open(f, "w") as fh: + fh.write("Flag1 date\n") fdf = fdfSileSiesta(f) - assert fdf.get('Flag1') == 'date' - fdf.set('Flag1', 'not-date') - assert fdf.get('Flag1') == 'not-date' - fdf.set('Flag1', 'date') - assert fdf.get('Flag1') == 'date' - fdf.set('Flag1', 'date-date') - assert fdf.get('Flag1') == 'date-date' - fdf.set('Flag1', 'date-date', keep=False) + assert fdf.get("Flag1") == "date" + fdf.set("Flag1", "not-date") + assert fdf.get("Flag1") == "not-date" + fdf.set("Flag1", "date") + assert fdf.get("Flag1") == "date" + fdf.set("Flag1", "date-date") + assert fdf.get("Flag1") == "date-date" + fdf.set("Flag1", "date-date", keep=False) def test_get_block(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) - with open(f, 'w') as fh: - fh.write('%block MyBlock\n date\n%endblock\n') + f = sisl_tmp("file.fdf", _dir) + with open(f, "w") as fh: + fh.write("%block MyBlock\n date\n%endblock\n") fdf = fdfSileSiesta(f) - assert isinstance(fdf.get('MyBlock'), list) - assert fdf.get('MyBlock')[0] == 'date' - assert 'block' in fdf.print("MyBlock", fdf.get("MyBlock")) + assert isinstance(fdf.get("MyBlock"), list) + assert fdf.get("MyBlock")[0] == "date" + assert "block" in fdf.print("MyBlock", fdf.get("MyBlock")) def test_include(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) - with open(f, 'w') as fh: - fh.write('Flag1 date\n') - fh.write('# Flag2 comment\n') - fh.write('Flag2 date2\n') - fh.write('# Flag3 is read through < from file hello\n') - fh.write('Flag3 Sub < hello\n') - fh.write('FakeInt 1\n') - fh.write('Test 1. eV\n') - fh.write(' %INCLUDE file2.fdf\n') - fh.write('TestRy 1. Ry\n') - fh.write('%block Hello < hello\n') - fh.write('\n') - fh.write('TestLast 1. eV\n') - - hello = sisl_tmp('hello', _dir) - with open(hello, 'w') as fh: - fh.write('Flag4 hello\n') - fh.write('# Comments should be discarded\n') - fh.write('Flag3 test\n') - fh.write('Sub sub-test\n') - - file2 = sisl_tmp('file2.fdf', _dir) - with open(file2, 'w') as fh: - fh.write('Flag4 non\n') - fh.write('\n') - fh.write('FakeReal 2.\n') - fh.write(' %incLude file3.fdf') - - file3 = sisl_tmp('file3.fdf', _dir) - with open(file3, 'w') as fh: - fh.write('Sub level\n') - fh.write('Third level\n') - fh.write('MyList [1 , 2 , 3]\n') - + f = sisl_tmp("file.fdf", _dir) + with open(f, "w") as fh: + fh.write("Flag1 date\n") + fh.write("# Flag2 comment\n") + fh.write("Flag2 date2\n") + fh.write("# Flag3 is read through < from file hello\n") + fh.write("Flag3 Sub < hello\n") + fh.write("FakeInt 1\n") + fh.write("Test 1. eV\n") + fh.write(" %INCLUDE file2.fdf\n") + fh.write("TestRy 1. Ry\n") + fh.write("%block Hello < hello\n") + fh.write("\n") + fh.write("TestLast 1. eV\n") + + hello = sisl_tmp("hello", _dir) + with open(hello, "w") as fh: + fh.write("Flag4 hello\n") + fh.write("# Comments should be discarded\n") + fh.write("Flag3 test\n") + fh.write("Sub sub-test\n") + + file2 = sisl_tmp("file2.fdf", _dir) + with open(file2, "w") as fh: + fh.write("Flag4 non\n") + fh.write("\n") + fh.write("FakeReal 2.\n") + fh.write(" %incLude file3.fdf") + + file3 = sisl_tmp("file3.fdf", _dir) + with open(file3, "w") as fh: + fh.write("Sub level\n") + fh.write("Third level\n") + fh.write("MyList [1 , 2 , 3]\n") fdf = fdfSileSiesta(f, base=sisl_tmp.getbase()) assert fdf.includes() == [Path(hello), Path(file2), Path(file3)] - assert fdf.get('Flag1') == 'date' - assert fdf.get('Flag2') == 'date2' - assert fdf.get('Flag3') == 'test' - assert fdf.get('Flag4') == 'non' - assert fdf.get('FLAG4') == 'non' - assert fdf.get('Fakeint') == 1 - assert fdf.get('Fakeint', '0') == '1' - assert fdf.get('Fakereal') == 2. - assert fdf.get('Fakereal', 0.) == 2. - assert fdf.get('test', 'eV') == pytest.approx(1.) - assert fdf.get('test', with_unit=True)[0] == pytest.approx(1.) - assert fdf.get('test', with_unit=True)[1] == 'eV' - assert fdf.get('test', unit='Ry') == pytest.approx(unit_convert('eV', 'Ry')) - assert fdf.get('testRy') == pytest.approx(unit_convert('Ry', 'eV')) - assert fdf.get('testRy', with_unit=True)[0] == pytest.approx(1.) - assert fdf.get('testRy', with_unit=True)[1] == 'Ry' - assert fdf.get('testRy', unit='Ry') == pytest.approx(1.) - assert fdf.get('Sub') == 'sub-test' - assert fdf.get('Third') == 'level' - assert fdf.get('test-last', with_unit=True)[0] == pytest.approx(1.) - assert fdf.get('test-last', with_unit=True)[1] == 'eV' + assert fdf.get("Flag1") == "date" + assert fdf.get("Flag2") == "date2" + assert fdf.get("Flag3") == "test" + assert fdf.get("Flag4") == "non" + assert fdf.get("FLAG4") == "non" + assert fdf.get("Fakeint") == 1 + assert fdf.get("Fakeint", "0") == "1" + assert fdf.get("Fakereal") == 2.0 + assert fdf.get("Fakereal", 0.0) == 2.0 + assert fdf.get("test", "eV") == pytest.approx(1.0) + assert fdf.get("test", with_unit=True)[0] == pytest.approx(1.0) + assert fdf.get("test", with_unit=True)[1] == "eV" + assert fdf.get("test", unit="Ry") == pytest.approx(unit_convert("eV", "Ry")) + assert fdf.get("testRy") == pytest.approx(unit_convert("Ry", "eV")) + assert fdf.get("testRy", with_unit=True)[0] == pytest.approx(1.0) + assert fdf.get("testRy", with_unit=True)[1] == "Ry" + assert fdf.get("testRy", unit="Ry") == pytest.approx(1.0) + assert fdf.get("Sub") == "sub-test" + assert fdf.get("Third") == "level" + assert fdf.get("test-last", with_unit=True)[0] == pytest.approx(1.0) + assert fdf.get("test-last", with_unit=True)[1] == "eV" # Currently lists are not implemented - #assert np.allclose(fdf.get('MyList'), np.arange(3) + 1) - #assert np.allclose(fdf.get('MyList', []), np.arange(3) + 1) + # assert np.allclose(fdf.get('MyList'), np.arange(3) + 1) + # assert np.allclose(fdf.get('MyList', []), np.arange(3) + 1) # Read a block - ll = open(sisl_tmp('hello', _dir)).readlines() + ll = open(sisl_tmp("hello", _dir)).readlines() ll.pop(1) - assert fdf.get('Hello') == [l.replace('\n', '').strip() for l in ll] + assert fdf.get("Hello") == [l.replace("\n", "").strip() for l in ll] def test_xv_preference(sisl_tmp): g = geom.graphene() - g.write(sisl_tmp('file.fdf', _dir)) - g.xyz[0, 0] += 1. - g.write(sisl_tmp('siesta.XV', _dir)) + g.write(sisl_tmp("file.fdf", _dir)) + g.xyz[0, 0] += 1.0 + g.write(sisl_tmp("siesta.XV", _dir)) - lattice = fdfSileSiesta(sisl_tmp('file.fdf', _dir)).read_lattice(True) - g2 = fdfSileSiesta(sisl_tmp('file.fdf', _dir)).read_geometry(True) + lattice = fdfSileSiesta(sisl_tmp("file.fdf", _dir)).read_lattice(True) + g2 = fdfSileSiesta(sisl_tmp("file.fdf", _dir)).read_geometry(True) assert np.allclose(lattice.cell, g.cell) assert np.allclose(g.cell, g2.cell) assert np.allclose(g.xyz, g2.xyz) - g2 = fdfSileSiesta(sisl_tmp('file.fdf', _dir)).read_geometry(order=['fdf']) + g2 = fdfSileSiesta(sisl_tmp("file.fdf", _dir)).read_geometry(order=["fdf"]) assert np.allclose(g.cell, g2.cell) - g2.xyz[0, 0] += 1. + g2.xyz[0, 0] += 1.0 assert np.allclose(g.xyz, g2.xyz) @@ -315,20 +313,20 @@ def test_geom_order(sisl_tmp): gnc = gfdf.copy() gnc.xyz[0, 0] += 0.5 - gfdf.write(sisl_tmp('siesta.fdf', _dir)) + gfdf.write(sisl_tmp("siesta.fdf", _dir)) # Create fdf-file - fdf = fdfSileSiesta(sisl_tmp('siesta.fdf', _dir)) - assert fdf.read_geometry(order=['nc']) is None - gxv.write(sisl_tmp('siesta.XV', _dir)) - gnc.write(sisl_tmp('siesta.nc', _dir)) + fdf = fdfSileSiesta(sisl_tmp("siesta.fdf", _dir)) + assert fdf.read_geometry(order=["nc"]) is None + gxv.write(sisl_tmp("siesta.XV", _dir)) + gnc.write(sisl_tmp("siesta.nc", _dir)) # Should read from XV g = fdf.read_geometry(True) assert np.allclose(g.xyz, gxv.xyz) - g = fdf.read_geometry(order=['nc', 'fdf']) + g = fdf.read_geometry(order=["nc", "fdf"]) assert np.allclose(g.xyz, gnc.xyz) - g = fdf.read_geometry(order=['fdf', 'nc']) + g = fdf.read_geometry(order=["fdf", "nc"]) assert np.allclose(g.xyz, gfdf.xyz) g = fdf.read_geometry(True, order="^fdf") assert np.allclose(g.xyz, gxv.xyz) @@ -336,16 +334,16 @@ def test_geom_order(sisl_tmp): def test_geom_constraints(sisl_tmp): gfdf = geom.graphene().tile(2, 0).tile(2, 1) - gfdf['CONSTRAIN'] = 0 - gfdf['CONSTRAIN-x'] = 2 - gfdf['CONSTRAIN-y'] = [1, 3, 4, 5] - gfdf['CONSTRAIN-z'] = range(len(gfdf)) + gfdf["CONSTRAIN"] = 0 + gfdf["CONSTRAIN-x"] = 2 + gfdf["CONSTRAIN-y"] = [1, 3, 4, 5] + gfdf["CONSTRAIN-z"] = range(len(gfdf)) - gfdf.write(sisl_tmp('siesta.fdf', _dir)) + gfdf.write(sisl_tmp("siesta.fdf", _dir)) def test_h2_dynamical_matrix(sisl_files): - si = fdfSileSiesta(sisl_files(_dir, 'H2_dynamical_matrix.fdf')) + si = fdfSileSiesta(sisl_files(_dir, "H2_dynamical_matrix.fdf")) trans_inv = [True, False] sum0 = trans_inv[:] @@ -355,6 +353,7 @@ def test_h2_dynamical_matrix(sisl_files): hw_true = [-88.392650, -88.392650, -0.000038, -0.000001, 0.000025, 3797.431825] from itertools import product + for ti, s0, herm in product(trans_inv, sum0, hermitian): dyn = si.read_dynamical_matrix(trans_inv=ti, sum0=s0, hermitian=herm) hw = dyn.eigenvalue().hw @@ -365,7 +364,7 @@ def test_h2_dynamical_matrix(sisl_files): def test_dry_read(sisl_tmp): # This test runs the read-functions. They aren't expected to actually read anything, # it is only a dry-run. - file = sisl_tmp('siesta.fdf', _dir) + file = sisl_tmp("siesta.fdf", _dir) geom.graphene().write(file) fdf = fdfSileSiesta(file) @@ -396,18 +395,18 @@ def test_dry_read(sisl_tmp): def test_fdf_argumentparser(sisl_tmp): - f = sisl_tmp('file.fdf', _dir) - with open(f, 'w') as fh: - fh.write('Flag1 date\n') - fh.write('Flag1 not-date\n') - fh.write('Flag1 not-date-2\n') - fh.write('Flag3 true\n') + f = sisl_tmp("file.fdf", _dir) + with open(f, "w") as fh: + fh.write("Flag1 date\n") + fh.write("Flag1 not-date\n") + fh.write("Flag1 not-date-2\n") + fh.write("Flag3 true\n") fdfSileSiesta(f).ArgumentParser() def test_fdf_fe_basis(sisl_files): - geom = fdfSileSiesta(sisl_files(_dir, 'fe.fdf')).read_geometry() + geom = fdfSileSiesta(sisl_files(_dir, "fe.fdf")).read_geometry() assert geom.no == 15 assert geom.na == 1 @@ -471,7 +470,7 @@ def test_fdf_pao_basis(): def test_fdf_gz(sisl_files): - f = sisl_files(osp.join(_dir, 'fdf'), 'main.fdf.gz') + f = sisl_files(osp.join(_dir, "fdf"), "main.fdf.gz") fdf = fdfSileSiesta(f) # read from gzipped file @@ -486,7 +485,7 @@ def test_fdf_gz(sisl_files): assert fdf.get("Lvl3.Foo") == "world3" assert fdf.get("Lvl3.Bar") == "hello3" - f = sisl_files(osp.join(_dir, 'fdf'), 'level2.fdf') + f = sisl_files(osp.join(_dir, "fdf"), "level2.fdf") fdf = fdfSileSiesta(f) # read from non-gzipped file @@ -501,7 +500,7 @@ def test_fdf_gz(sisl_files): @pytest.mark.xfail(reason="wrong set handling for blocks etc") def test_fdf_block_write_print(sisl_tmp): f = sisl_tmp("block_write_print.fdf", _dir) - fdf = fdfSileSiesta(f, 'w') + fdf = fdfSileSiesta(f, "w") block_value = ["this is my life"] fdf.set("hello", block_value) @@ -514,15 +513,20 @@ def test_fdf_block_write_print(sisl_tmp): assert f"""%block hello {block_value[0]} %endblock hello -""" == fdf.print("hello") +""" == fdf.print( + "hello" + ) def test_fdf_write_bandstructure(sisl_tmp, sisl_system): - f = sisl_tmp('gr.fdf', _dir) + f = sisl_tmp("gr.fdf", _dir) - bs = sisl.BandStructure(sisl_system.g, [ - [0, 0, 0], [0.5, 0.5, 0.5], - [0.25, 0.5, 0]], 200, names=["Gamma", "Edge", "L"]) + bs = sisl.BandStructure( + sisl_system.g, + [[0, 0, 0], [0.5, 0.5, 0.5], [0.25, 0.5, 0]], + 200, + names=["Gamma", "Edge", "L"], + ) with fdfSileSiesta(f, "w") as fdf: fdf.write_brillouinzone(bs) @@ -530,4 +534,3 @@ def test_fdf_write_bandstructure(sisl_tmp, sisl_system): with fdfSileSiesta(f) as fdf: block = fdf.get("BandLines") assert len(block) == 3 - diff --git a/src/sisl/io/siesta/tests/test_gf.py b/src/sisl/io/siesta/tests/test_gf.py index c5573c89c4..60235d4cbf 100644 --- a/src/sisl/io/siesta/tests/test_gf.py +++ b/src/sisl/io/siesta/tests/test_gf.py @@ -9,12 +9,12 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_gf_write(sisl_tmp, sisl_system): tb = sisl.Hamiltonian(sisl_system.gtb) - f = sisl_tmp('file.TSGF', _dir) + f = sisl_tmp("file.TSGF", _dir) gf = sisl.io.get_sile(f) bz = sisl.MonkhorstPack(tb, [3, 3, 1]) E = np.linspace(-2, 2, 20) + 1j * 1e-4 @@ -22,7 +22,7 @@ def test_gf_write(sisl_tmp, sisl_system): gf.write_header(bz, E) for i, (ispin, new_hs, k, e) in enumerate(gf): - Hk = tb.Hk(k, format='array') + Hk = tb.Hk(k, format="array") assert ispin == 0 if new_hs and i % 2 == 0: gf.write_hamiltonian(Hk) @@ -33,7 +33,7 @@ def test_gf_write(sisl_tmp, sisl_system): def test_gf_write_read(sisl_tmp, sisl_system): tb = sisl.Hamiltonian(sisl_system.gtb) - f = sisl_tmp('file.TSGF', _dir) + f = sisl_tmp("file.TSGF", _dir) bz = sisl.MonkhorstPack(tb, [3, 3, 1]) E = np.linspace(-2, 2, 20) + 1j * 1e-4 @@ -44,7 +44,7 @@ def test_gf_write_read(sisl_tmp, sisl_system): gf.write_header(bz, E) for i, (ispin, write_hs, k, e) in enumerate(gf): assert ispin == 0 - Hk = tb.Hk(k, format='array') + Hk = tb.Hk(k, format="array") if write_hs and i % 2 == 0: gf.write_hamiltonian(Hk) elif write_hs: @@ -58,7 +58,7 @@ def test_gf_write_read(sisl_tmp, sisl_system): for i, (ispin, write_hs, k, e) in enumerate(gf): assert ispin == 0 - Hk = tb.Hk(k, format='array') + Hk = tb.Hk(k, format="array") if write_hs and i % 2 == 0: Hk_file, _ = gf.read_hamiltonian() elif write_hs: @@ -71,9 +71,9 @@ def test_gf_write_read(sisl_tmp, sisl_system): def test_gf_write_read_spin(sisl_tmp, sisl_system): - f = sisl_tmp('file.TSGF', _dir) + f = sisl_tmp("file.TSGF", _dir) - tb = sisl.Hamiltonian(sisl_system.gtb, spin=sisl.Spin('P')) + tb = sisl.Hamiltonian(sisl_system.gtb, spin=sisl.Spin("P")) tb.construct([(0.1, 1.5), ([0.1, -0.1], [2.7, 1.6])]) bz = sisl.MonkhorstPack(tb, [3, 3, 1]) @@ -84,7 +84,7 @@ def test_gf_write_read_spin(sisl_tmp, sisl_system): gf.write_header(bz, E) for i, (ispin, write_hs, k, e) in enumerate(gf): - Hk = tb.Hk(k, spin=ispin, format='array') + Hk = tb.Hk(k, spin=ispin, format="array") if write_hs and i % 2 == 0: gf.write_hamiltonian(Hk) elif write_hs: @@ -100,7 +100,7 @@ def test_gf_write_read_spin(sisl_tmp, sisl_system): assert np.allclose(k, bz.k) for i, (ispin, write_hs, k, e) in enumerate(gf): - Hk = tb.Hk(k, spin=ispin, format='array') + Hk = tb.Hk(k, spin=ispin, format="array") if write_hs and i % 2 == 0: Hk_file, _ = gf.read_hamiltonian() elif write_hs: @@ -113,9 +113,9 @@ def test_gf_write_read_spin(sisl_tmp, sisl_system): def test_gf_write_read_direct(sisl_tmp, sisl_system): - f = sisl_tmp('file.TSGF', _dir) + f = sisl_tmp("file.TSGF", _dir) - tb = sisl.Hamiltonian(sisl_system.gtb, spin=sisl.Spin('P')) + tb = sisl.Hamiltonian(sisl_system.gtb, spin=sisl.Spin("P")) tb.construct([(0.1, 1.5), ([0.1, -0.1], [2.7, 1.6])]) bz = sisl.MonkhorstPack(tb, [3, 3, 1]) @@ -126,7 +126,7 @@ def test_gf_write_read_direct(sisl_tmp, sisl_system): gf.write_header(bz, E) for i, (ispin, write_hs, k, e) in enumerate(gf): - Hk = tb.Hk(k, spin=ispin, format='array') + Hk = tb.Hk(k, spin=ispin, format="array") if write_hs and i % 2 == 0: gf.write_hamiltonian(Hk) elif write_hs: @@ -196,4 +196,4 @@ def test_gf_write_read_direct(sisl_tmp, sisl_system): def test_gf_sile_error(): with pytest.raises(sisl.SileError): - sisl.get_sile('non_existing_file.TSGF').read_header() + sisl.get_sile("non_existing_file.TSGF").read_header() diff --git a/src/sisl/io/siesta/tests/test_grid.py b/src/sisl/io/siesta/tests/test_grid.py index 3df135aaec..423ce8e5c3 100644 --- a/src/sisl/io/siesta/tests/test_grid.py +++ b/src/sisl/io/siesta/tests/test_grid.py @@ -9,29 +9,29 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_si_pdos_kgrid_grid(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.VT')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.VT")) si.read_grid() - assert si.grid_unit == pytest.approx(sisl.unit.siesta.unit_convert('Ry', 'eV')) + assert si.grid_unit == pytest.approx(sisl.unit.siesta.unit_convert("Ry", "eV")) def test_si_pdos_kgrid_grid_cell(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.VT')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.VT")) si.read_lattice() def test_si_pdos_kgrid_grid_fractions(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.VT')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.VT")) grid = si.read_grid() grid_halve = si.read_grid(index=[0.5]) assert np.allclose(grid.grid * 0.5, grid_halve.grid) def test_si_pdos_kgrid_grid_fdf(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf')) - VT = si.read_grid("VT", order='bin') - TotPot = si.read_grid("totalpotential", order='bin') + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf")) + VT = si.read_grid("VT", order="bin") + TotPot = si.read_grid("totalpotential", order="bin") assert np.allclose(VT.grid, TotPot.grid) diff --git a/src/sisl/io/siesta/tests/test_kp.py b/src/sisl/io/siesta/tests/test_kp.py index 5a39c0450b..dea3453517 100644 --- a/src/sisl/io/siesta/tests/test_kp.py +++ b/src/sisl/io/siesta/tests/test_kp.py @@ -12,10 +12,10 @@ def test_kp_read_write(sisl_tmp): - f = sisl_tmp('tmp.KP') + f = sisl_tmp("tmp.KP") g = geom.graphene() bz = MonkhorstPack(g, [10, 10, 10]) - kpSileSiesta(f, 'w').write_brillouinzone(bz) + kpSileSiesta(f, "w").write_brillouinzone(bz) kpoints, weights = kpSileSiesta(f).read_data(g) assert np.allclose(kpoints, bz.k) @@ -31,10 +31,10 @@ def test_kp_read_write(sisl_tmp): def test_rkp_read_write(sisl_tmp): - f = sisl_tmp('tmp.RKP') + f = sisl_tmp("tmp.RKP") g = geom.graphene() bz = MonkhorstPack(g, [10, 10, 10]) - rkpSileSiesta(f, 'w').write_brillouinzone(bz) + rkpSileSiesta(f, "w").write_brillouinzone(bz) kpoints, weights = rkpSileSiesta(f).read_data() assert np.allclose(kpoints, bz.k) diff --git a/src/sisl/io/siesta/tests/test_orb_indx.py b/src/sisl/io/siesta/tests/test_orb_indx.py index c5595f8131..6f0bf23e49 100644 --- a/src/sisl/io/siesta/tests/test_orb_indx.py +++ b/src/sisl/io/siesta/tests/test_orb_indx.py @@ -9,11 +9,11 @@ from sisl.io.siesta.orb_indx import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_si_pdos_kgrid_orb_indx(sisl_files): - f = sisl_files(_dir, 'si_pdos_kgrid.ORB_INDX') + f = sisl_files(_dir, "si_pdos_kgrid.ORB_INDX") nsc = orbindxSileSiesta(f).read_lattice_nsc() assert np.all(nsc > 1) atoms = orbindxSileSiesta(f).read_basis() @@ -25,7 +25,7 @@ def test_si_pdos_kgrid_orb_indx(sisl_files): def test_sih_orb_indx(sisl_files): - f = sisl_files(_dir, 'sih.ORB_INDX') + f = sisl_files(_dir, "sih.ORB_INDX") nsc = orbindxSileSiesta(f).read_lattice_nsc() assert np.all(nsc == 1) atoms = orbindxSileSiesta(f).read_basis() diff --git a/src/sisl/io/siesta/tests/test_pdos.py b/src/sisl/io/siesta/tests/test_pdos.py index 6be47659cd..77bb6ca37c 100644 --- a/src/sisl/io/siesta/tests/test_pdos.py +++ b/src/sisl/io/siesta/tests/test_pdos.py @@ -11,11 +11,11 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_si_pdos_gamma(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_gamma.PDOS.xml')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_gamma.PDOS.xml")) geom, E, pdos = si.read_data() assert len(geom) == 2 assert len(E) == 500 @@ -24,18 +24,18 @@ def test_si_pdos_gamma(sisl_files): def test_si_pdos_gamma_xarray(sisl_files): pytest.importorskip("xarray", reason="xarray not available") - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_gamma.PDOS.xml')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_gamma.PDOS.xml")) X = si.read_data(as_dataarray=True) assert len(X.geometry) == 2 assert len(X.E) == 500 assert len(X.spin) == 1 - assert X.spin[0] == 'sum' + assert X.spin[0] == "sum" size = np.prod(X.shape[2:]) assert size >= X.geometry.no def test_si_pdos_kgrid(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.PDOS.xml')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.PDOS.xml")) geom, E, pdos = si.read_data() assert len(geom) == 2 assert len(E) == 500 @@ -44,11 +44,11 @@ def test_si_pdos_kgrid(sisl_files): def test_si_pdos_kgrid_xarray(sisl_files): pytest.importorskip("xarray", reason="xarray not available") - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.PDOS.xml')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.PDOS.xml")) X = si.read_data(as_dataarray=True) assert len(X.geometry) == 2 assert len(X.E) == 500 assert len(X.spin) == 1 - assert X.spin[0] == 'sum' + assert X.spin[0] == "sum" size = np.prod(X.shape[2:]) assert size >= X.geometry.no diff --git a/src/sisl/io/siesta/tests/test_siesta.py b/src/sisl/io/siesta/tests/test_siesta.py index ac05d10446..0ce0b448d5 100644 --- a/src/sisl/io/siesta/tests/test_siesta.py +++ b/src/sisl/io/siesta/tests/test_siesta.py @@ -18,16 +18,16 @@ from sisl.io.siesta import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") netCDF4 = pytest.importorskip("netCDF4") def test_nc1(sisl_tmp, sisl_system): - f = sisl_tmp('gr.nc', _dir) + f = sisl_tmp("gr.nc", _dir) tb = Hamiltonian(sisl_system.gtb) tb.construct([sisl_system.R, sisl_system.t]) - with ncSileSiesta(f, 'w') as s: + with ncSileSiesta(f, "w") as s: tb.write(s) with ncSileSiesta(f) as f: @@ -42,10 +42,10 @@ def test_nc1(sisl_tmp, sisl_system): def test_nc2(sisl_tmp, sisl_system): - f = sisl_tmp('grS.nc', _dir) + f = sisl_tmp("grS.nc", _dir) tb = Hamiltonian(sisl_system.gtb, orthogonal=False) tb.construct([sisl_system.R, sisl_system.tS]) - with ncSileSiesta(f, 'w') as s: + with ncSileSiesta(f, "w") as s: tb.write(s) with ncSileSiesta(f) as f: @@ -62,15 +62,15 @@ def test_nc2(sisl_tmp, sisl_system): def test_nc_multiple_fail(sisl_tmp, sisl_system): # writing two different sparse matrices to the same # file will fail - f = sisl_tmp('gr.nc', _dir) + f = sisl_tmp("gr.nc", _dir) H = Hamiltonian(sisl_system.gtb) DM = DensityMatrix(sisl_system.gtb) - with ncSileSiesta(f, 'w') as sile: + with ncSileSiesta(f, "w") as sile: H.construct([sisl_system.R, sisl_system.t]) H.write(sile) - DM[0, 0] = 1. + DM[0, 0] = 1.0 with pytest.raises(ValueError): DM.write(sile) @@ -80,11 +80,11 @@ def test_nc_multiple_fail(sisl_tmp, sisl_system): [True, False], ) def test_nc_multiple_checks(sisl_tmp, sisl_system, sort): - f = sisl_tmp('gr.nc', _dir) + f = sisl_tmp("gr.nc", _dir) H = Hamiltonian(sisl_system.gtb) DM = DensityMatrix(sisl_system.gtb) - with ncSileSiesta(f, 'w') as sile: + with ncSileSiesta(f, "w") as sile: H.construct([sisl_system.R, sisl_system.t]) H.write(sile, sort=sort) @@ -92,9 +92,9 @@ def test_nc_multiple_checks(sisl_tmp, sisl_system, sort): np.random.seed(42) shuffle = np.random.shuffle for io in range(len(H)): - edges = H.edges(io) # get all edges + edges = H.edges(io) # get all edges shuffle(edges) - DM[io, edges] = 2. + DM[io, edges] = 2.0 if not sort: with pytest.raises(ValueError): @@ -104,10 +104,10 @@ def test_nc_multiple_checks(sisl_tmp, sisl_system, sort): def test_nc_overlap(sisl_tmp, sisl_system): - f = sisl_tmp('gr.nc', _dir) + f = sisl_tmp("gr.nc", _dir) tb = Hamiltonian(sisl_system.gtb) tb.construct([sisl_system.R, sisl_system.t]) - tb.write(ncSileSiesta(f, 'w')) + tb.write(ncSileSiesta(f, "w")) with ncSileSiesta(f) as sile: S = sile.read_overlap() @@ -116,7 +116,7 @@ def test_nc_overlap(sisl_tmp, sisl_system): assert np.allclose(S._csr._D.sum(), tb.no) # Write test - f = sisl_tmp('s.nc', _dir) + f = sisl_tmp("s.nc", _dir) with ncSileSiesta(f, "w") as sile: S.write(sile) with ncSileSiesta(f) as sile: @@ -126,12 +126,12 @@ def test_nc_overlap(sisl_tmp, sisl_system): def test_nc_dynamical_matrix(sisl_tmp, sisl_system): - f = sisl_tmp('grdyn.nc', _dir) + f = sisl_tmp("grdyn.nc", _dir) dm = DynamicalMatrix(sisl_system.gtb) for _, ix in dm.iter_orbitals(): - dm[ix, ix] = ix / 2. + dm[ix, ix] = ix / 2.0 - with ncSileSiesta(f, 'w') as sile: + with ncSileSiesta(f, "w") as sile: dm.write(sile) with ncSileSiesta(f) as sile: @@ -146,12 +146,12 @@ def test_nc_dynamical_matrix(sisl_tmp, sisl_system): def test_nc_density_matrix(sisl_tmp, sisl_system): - f = sisl_tmp('grDM.nc', _dir) + f = sisl_tmp("grDM.nc", _dir) dm = DensityMatrix(sisl_system.gtb) for _, ix in dm.iter_orbitals(): - dm[ix, ix] = ix / 2. + dm[ix, ix] = ix / 2.0 - with ncSileSiesta(f, 'w') as sile: + with ncSileSiesta(f, "w") as sile: dm.write(sile) with ncSileSiesta(f) as sile: @@ -166,13 +166,11 @@ def test_nc_density_matrix(sisl_tmp, sisl_system): def test_nc_H_non_colinear(sisl_tmp): - H1 = Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin('NC')) - H1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4], - [0.2, 0.3, 0.4, 0.5]])) + H1 = Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin("NC")) + H1.construct(([0.1, 1.44], [[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5]])) - f1 = sisl_tmp('H1.nc', _dir) - f2 = sisl_tmp('H2.nc', _dir) + f1 = sisl_tmp("H1.nc", _dir) + f2 = sisl_tmp("H2.nc", _dir) H1.write(f1) H1.finalize() with sisl.get_sile(f1) as sile: @@ -187,13 +185,11 @@ def test_nc_H_non_colinear(sisl_tmp): def test_nc_DM_non_colinear(sisl_tmp): - DM1 = DensityMatrix(sisl.geom.graphene(), spin=sisl.Spin('NC')) - DM1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4], - [0.2, 0.3, 0.4, 0.5]])) + DM1 = DensityMatrix(sisl.geom.graphene(), spin=sisl.Spin("NC")) + DM1.construct(([0.1, 1.44], [[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5]])) - f1 = sisl_tmp('DM1.nc', _dir) - f2 = sisl_tmp('DM2.nc', _dir) + f1 = sisl_tmp("DM1.nc", _dir) + f2 = sisl_tmp("DM2.nc", _dir) DM1.write(f1) DM1.finalize() with sisl.get_sile(f1) as sile: @@ -212,13 +208,11 @@ def test_nc_DM_non_colinear(sisl_tmp): def test_nc_EDM_non_colinear(sisl_tmp): - EDM1 = EnergyDensityMatrix(sisl.geom.graphene(), spin=sisl.Spin('NC')) - EDM1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4], - [0.2, 0.3, 0.4, 0.5]])) + EDM1 = EnergyDensityMatrix(sisl.geom.graphene(), spin=sisl.Spin("NC")) + EDM1.construct(([0.1, 1.44], [[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5]])) - f1 = sisl_tmp('EDM1.nc', _dir) - f2 = sisl_tmp('EDM2.nc', _dir) + f1 = sisl_tmp("EDM1.nc", _dir) + f2 = sisl_tmp("EDM2.nc", _dir) EDM1.write(f1, sort=False) EDM1.finalize() with sisl.get_sile(f1) as sile: @@ -238,13 +232,19 @@ def test_nc_EDM_non_colinear(sisl_tmp): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_nc_H_spin_orbit(sisl_tmp): - H1 = Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin('SO')) - H1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) - - f1 = sisl_tmp('H1.nc', _dir) - f2 = sisl_tmp('H2.nc', _dir) + H1 = Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin("SO")) + H1.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) + + f1 = sisl_tmp("H1.nc", _dir) + f2 = sisl_tmp("H2.nc", _dir) H1.write(f1) H1.finalize() with sisl.get_sile(f1) as sile: @@ -260,13 +260,19 @@ def test_nc_H_spin_orbit(sisl_tmp): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_nc_H_spin_orbit_nc2tshs2nc(sisl_tmp): - H1 = Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin('SO')) - H1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) - - f1 = sisl_tmp('H1.nc', _dir) - f2 = sisl_tmp('H2.TSHS', _dir) + H1 = Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin("SO")) + H1.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) + + f1 = sisl_tmp("H1.nc", _dir) + f2 = sisl_tmp("H2.TSHS", _dir) H1.write(f1) H1.finalize() with sisl.get_sile(f1) as sile: @@ -282,13 +288,19 @@ def test_nc_H_spin_orbit_nc2tshs2nc(sisl_tmp): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_nc_DM_spin_orbit(sisl_tmp): - DM1 = DensityMatrix(sisl.geom.graphene(), spin=sisl.Spin('SO')) - DM1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) - - f1 = sisl_tmp('DM1.nc', _dir) - f2 = sisl_tmp('DM2.nc', _dir) + DM1 = DensityMatrix(sisl.geom.graphene(), spin=sisl.Spin("SO")) + DM1.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) + + f1 = sisl_tmp("DM1.nc", _dir) + f2 = sisl_tmp("DM2.nc", _dir) DM1.write(f1) DM1.finalize() with sisl.get_sile(f1) as sile: @@ -304,13 +316,19 @@ def test_nc_DM_spin_orbit(sisl_tmp): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_nc_DM_spin_orbit_nc2dm2nc(sisl_tmp): - DM1 = DensityMatrix(sisl.geom.graphene(), orthogonal=False, spin=sisl.Spin('SO')) - DM1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.]])) - - f1 = sisl_tmp('DM1.nc', _dir) - f2 = sisl_tmp('DM2.DM', _dir) + DM1 = DensityMatrix(sisl.geom.graphene(), orthogonal=False, spin=sisl.Spin("SO")) + DM1.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.0], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.0], + ], + ) + ) + + f1 = sisl_tmp("DM1.nc", _dir) + f2 = sisl_tmp("DM2.DM", _dir) DM1.finalize() DM1.write(f1) with sisl.get_sile(f1) as sile: @@ -325,11 +343,11 @@ def test_nc_DM_spin_orbit_nc2dm2nc(sisl_tmp): def test_nc_ghost(sisl_tmp): - f = sisl_tmp('ghost.nc', _dir) + f = sisl_tmp("ghost.nc", _dir) a1 = Atom(1) am1 = Atom(-1) - g = Geometry([[0., 0., i] for i in range(2)], [a1, am1], 2.) - g.write(ncSileSiesta(f, 'w')) + g = Geometry([[0.0, 0.0, i] for i in range(2)], [a1, am1], 2.0) + g.write(ncSileSiesta(f, "w")) with ncSileSiesta(f) as sile: g2 = sile.read_geometry() diff --git a/src/sisl/io/siesta/tests/test_stdout.py b/src/sisl/io/siesta/tests/test_stdout.py index 019ded090e..e88e798e47 100644 --- a/src/sisl/io/siesta/tests/test_stdout.py +++ b/src/sisl/io/siesta/tests/test_stdout.py @@ -12,11 +12,11 @@ from sisl.io.siesta.stdout import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_md_nose_out(sisl_files): - f = sisl_files(_dir, 'md_nose.out') + f = sisl_files(_dir, "md_nose.out") out = stdoutSileSiesta(f) # nspin, nk, nb @@ -44,14 +44,14 @@ def test_md_nose_out(sisl_files): assert not np.allclose(f0, f) assert np.allclose(f1, f) - #Check that we can read the different types of forces + # Check that we can read the different types of forces nAtoms = 10 atomicF = out.read_force(all=True) totalF = out.read_force(all=True, total=True) maxF = out.read_force(all=True, max=True) assert atomicF.shape == (nOutputs, nAtoms, 3) assert totalF.shape == (nOutputs, 3) - assert maxF.shape == (nOutputs, ) + assert maxF.shape == (nOutputs,) totalF, maxF = out.read_force(total=True, max=True) assert totalF.shape == (3,) assert maxF.shape == () @@ -60,9 +60,9 @@ def test_md_nose_out(sisl_files): s = out.read_stress() assert not np.allclose(s0, s) - sstatic = out.read_stress('static', all=True) - stotal = out.read_stress('total', all=True) - sdata = out.read_data('total', all=True, stress=True) + sstatic = out.read_stress("static", all=True) + stotal = out.read_stress("total", all=True) + sdata = out.read_data("total", all=True, stress=True) for S, T, D in zip(sstatic, stotal, sdata): assert not np.allclose(S, T) @@ -85,7 +85,7 @@ def test_md_nose_out(sisl_files): def test_md_nose_out_data(sisl_files): - f = sisl_files(_dir, 'md_nose.out') + f = sisl_files(_dir, "md_nose.out") out = stdoutSileSiesta(f) f0, g0 = out.read_data(force=True, geometry=True) @@ -100,14 +100,14 @@ def test_md_nose_out_data(sisl_files): def test_md_nose_out_completed(sisl_files): - f = sisl_files(_dir, 'md_nose.out') + f = sisl_files(_dir, "md_nose.out") out = stdoutSileSiesta(f) out.completed() def test_md_nose_out_dataframe(sisl_files): pytest.importorskip("pandas", reason="pandas not available") - f = sisl_files(_dir, 'md_nose.out') + f = sisl_files(_dir, "md_nose.out") out = stdoutSileSiesta(f) data = out.read_scf() @@ -123,7 +123,7 @@ def test_md_nose_out_dataframe(sisl_files): def test_md_nose_out_energy(sisl_files): - f = sisl_files(_dir, 'md_nose.out') + f = sisl_files(_dir, "md_nose.out") energy = stdoutSileSiesta(f).read_energy() assert isinstance(energy, sisl.utils.PropertyDict) assert hasattr(energy, "basis") @@ -132,7 +132,7 @@ def test_md_nose_out_energy(sisl_files): def test_md_nose_pao_basis(sisl_files): - f = sisl_files(_dir, 'md_nose.out') + f = sisl_files(_dir, "md_nose.out") block = """ Mg 1 # Species label, number of l-shells diff --git a/src/sisl/io/siesta/tests/test_stdout_charges.py b/src/sisl/io/siesta/tests/test_stdout_charges.py index 05d0b29c23..f36d2f1edb 100644 --- a/src/sisl/io/siesta/tests/test_stdout_charges.py +++ b/src/sisl/io/siesta/tests/test_stdout_charges.py @@ -12,7 +12,7 @@ from sisl.io.siesta.stdout import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta', 'outs') +_dir = osp.join("sisl", "io", "siesta", "outs") # tests here tests charge reads for output # voronoi + hirshfeld: test_vh_* @@ -27,12 +27,13 @@ def with_pandas(): try: import pandas + return True except ImportError: return False -@pytest.mark.parametrize('name', ("voronoi", "Hirshfeld")) +@pytest.mark.parametrize("name", ("voronoi", "Hirshfeld")) def test_vh_empty_file(name, sisl_files): f = sisl_files(_dir, "voronoi_hirshfeld_4.1_none.out") out = stdoutSileSiesta(f) @@ -50,7 +51,7 @@ def test_vh_empty_file(name, sisl_files): out.read_charge(name, iscf=None, imd=-1) -@pytest.mark.parametrize('name', ("voronoi", "Hirshfeld")) +@pytest.mark.parametrize("name", ("voronoi", "Hirshfeld")) def test_vh_final(name, sisl_files): f = sisl_files(_dir, "voronoi_hirshfeld.out") out = stdoutSileSiesta(f) @@ -76,8 +77,8 @@ def test_vh_final(name, sisl_files): assert np.allclose(df.values, q) -@pytest.mark.parametrize('fname', ("md", "4.1_pol_md", "nc_md")) -@pytest.mark.parametrize('name', ("voronoi", "Hirshfeld")) +@pytest.mark.parametrize("fname", ("md", "4.1_pol_md", "nc_md")) +@pytest.mark.parametrize("name", ("voronoi", "Hirshfeld")) def test_vh_md(name, fname, sisl_files): # voronoi_hirshfeld_md.out f = sisl_files(_dir, f"voronoi_hirshfeld_{fname}.out") @@ -107,8 +108,8 @@ def test_vh_md(name, fname, sisl_files): assert np.allclose(q[-1].ravel(), df.values.ravel()) -@pytest.mark.parametrize('fname', ("md_scf", "nc_md_scf", "pol_md_scf", "soc_md_scf")) -@pytest.mark.parametrize('name', ("voronoi", "Hirshfeld")) +@pytest.mark.parametrize("fname", ("md_scf", "nc_md_scf", "pol_md_scf", "soc_md_scf")) +@pytest.mark.parametrize("name", ("voronoi", "Hirshfeld")) def test_vh_md_scf(name, fname, sisl_files): f = sisl_files(_dir, f"voronoi_hirshfeld_{fname}.out") out = stdoutSileSiesta(f) diff --git a/src/sisl/io/siesta/tests/test_struct.py b/src/sisl/io/siesta/tests/test_struct.py index 556f6f3292..883a371a17 100644 --- a/src/sisl/io/siesta/tests/test_struct.py +++ b/src/sisl/io/siesta/tests/test_struct.py @@ -10,12 +10,12 @@ from sisl.io.siesta.struct import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_struct1(sisl_tmp, sisl_system): - f = sisl_tmp('gr.STRUCT_IN', _dir) - sisl_system.g.write(structSileSiesta(f, 'w')) + f = sisl_tmp("gr.STRUCT_IN", _dir) + sisl_system.g.write(structSileSiesta(f, "w")) g = structSileSiesta(f).read_geometry() # Assert they are the same @@ -25,10 +25,10 @@ def test_struct1(sisl_tmp, sisl_system): def test_struct_reorder(sisl_tmp, sisl_system): - f = sisl_tmp('gr.STRUCT_IN', _dir) + f = sisl_tmp("gr.STRUCT_IN", _dir) g = sisl_system.g.copy() g.atoms[0] = Atom(1) - g.write(structSileSiesta(f, 'w')) + g.write(structSileSiesta(f, "w")) g2 = structSileSiesta(f).read_geometry() # Assert they are the same @@ -38,11 +38,11 @@ def test_struct_reorder(sisl_tmp, sisl_system): def test_struct_ghost(sisl_tmp): - f = sisl_tmp('ghost.STRUCT_IN', _dir) + f = sisl_tmp("ghost.STRUCT_IN", _dir) a1 = Atom(1) am1 = Atom(-1) - g = Geometry([[0., 0., i] for i in range(2)], [a1, am1], 2.) - g.write(structSileSiesta(f, 'w')) + g = Geometry([[0.0, 0.0, i] for i in range(2)], [a1, am1], 2.0) + g.write(structSileSiesta(f, "w")) g2 = structSileSiesta(f).read_geometry() assert np.allclose(g.cell, g2.cell) @@ -54,16 +54,16 @@ def test_struct_ghost(sisl_tmp): def test_si_pdos_kgrid_struct_out(sisl_files): - fdf = get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf')) - struct = get_sile(sisl_files(_dir, 'si_pdos_kgrid.STRUCT_OUT')) + fdf = get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf")) + struct = get_sile(sisl_files(_dir, "si_pdos_kgrid.STRUCT_OUT")) struct_geom = struct.read_geometry() - fdf_geom = fdf.read_geometry(order='STRUCT') + fdf_geom = fdf.read_geometry(order="STRUCT") assert np.allclose(struct_geom.cell, fdf_geom.cell) assert np.allclose(struct_geom.xyz, fdf_geom.xyz) struct_sc = struct.read_lattice() - fdf_sc = fdf.read_lattice(order='STRUCT') + fdf_sc = fdf.read_lattice(order="STRUCT") assert np.allclose(struct_sc.cell, fdf_sc.cell) diff --git a/src/sisl/io/siesta/tests/test_tsde.py b/src/sisl/io/siesta/tests/test_tsde.py index a031811093..479529f953 100644 --- a/src/sisl/io/siesta/tests/test_tsde.py +++ b/src/sisl/io/siesta/tests/test_tsde.py @@ -11,16 +11,16 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_si_pdos_kgrid_tsde_dm(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf"), base=sisl_files(_dir)) - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSDE')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSDE")) DM1 = si.read_density_matrix(geometry=fdf.read_geometry()) - DM2 = fdf.read_density_matrix(order=['TSDE']) + DM2 = fdf.read_density_matrix(order=["TSDE"]) Ef1 = si.read_fermi_level() Ef2 = fdf.read_fermi_level() @@ -32,12 +32,12 @@ def test_si_pdos_kgrid_tsde_dm(sisl_files): def test_si_pdos_kgrid_tsde_edm(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf"), base=sisl_files(_dir)) - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSDE')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSDE")) EDM1 = si.read_energy_density_matrix(geometry=fdf.read_geometry()) - EDM2 = fdf.read_energy_density_matrix(order=['TSDE']) + EDM2 = fdf.read_energy_density_matrix(order=["TSDE"]) assert EDM1._csr.spsame(EDM2._csr) assert np.allclose(EDM1._csr._D[:, :-1], EDM2._csr._D[:, :-1]) @@ -45,14 +45,14 @@ def test_si_pdos_kgrid_tsde_edm(sisl_files): @pytest.mark.filterwarnings("ignore", message="*wrong sparse pattern") def test_si_pdos_kgrid_tsde_dm_edm_rw(sisl_files, sisl_tmp): - fdf = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.fdf"), base=sisl_files(_dir)) geom = fdf.read_geometry() - f1 = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSDE')) + f1 = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSDE")) DM1 = f1.read_density_matrix(geometry=geom) EDM1 = f1.read_energy_density_matrix(geometry=geom) - f2 = sisl.get_sile(sisl_tmp('noEf.TSDE', _dir)) + f2 = sisl.get_sile(sisl_tmp("noEf.TSDE", _dir)) # by default everything gets sorted... f2.write_density_matrices(DM1, EDM1, sort=False) DM2 = f2.read_density_matrix(geometry=geom) @@ -64,9 +64,9 @@ def test_si_pdos_kgrid_tsde_dm_edm_rw(sisl_files, sisl_tmp): # Now the matrices ARE finalized, we don't have to do anything again EDM2 = EDM1.copy() - EDM2.shift(-2., DM1) - f3 = sisl.get_sile(sisl_tmp('Ef.TSDE', _dir)) - f3.write_density_matrices(DM1, EDM2, Ef=-2., sort=False) + EDM2.shift(-2.0, DM1) + f3 = sisl.get_sile(sisl_tmp("Ef.TSDE", _dir)) + f3.write_density_matrices(DM1, EDM2, Ef=-2.0, sort=False) DM3 = f3.read_density_matrix(geometry=geom) EDM3 = f3.read_energy_density_matrix(geometry=geom) assert DM1._csr.spsame(DM3._csr) @@ -74,7 +74,7 @@ def test_si_pdos_kgrid_tsde_dm_edm_rw(sisl_files, sisl_tmp): assert EDM1._csr.spsame(EDM3._csr) assert np.allclose(EDM1._csr._D[:, :-1], EDM3._csr._D[:, :-1]) - f3.write_density_matrices(DM1, EDM2, Ef=-2., sort=True) + f3.write_density_matrices(DM1, EDM2, Ef=-2.0, sort=True) DM3 = f3.read_density_matrix(geometry=geom, sort=False) EDM3 = f3.read_energy_density_matrix(geometry=geom, sort=False) assert DM1._csr.spsame(DM3._csr) diff --git a/src/sisl/io/siesta/tests/test_tshs.py b/src/sisl/io/siesta/tests/test_tshs.py index 732e681a9f..878f33dc8e 100644 --- a/src/sisl/io/siesta/tests/test_tshs.py +++ b/src/sisl/io/siesta/tests/test_tshs.py @@ -11,14 +11,14 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_tshs_si_pdos_kgrid(sisl_files, sisl_tmp): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSHS')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSHS")) assert si.version == 1 HS1 = si.read_hamiltonian() - f = sisl_tmp('tmp.TSHS', _dir) + f = sisl_tmp("tmp.TSHS", _dir) HS1.write(f) si = sisl.get_sile(f) HS2 = si.read_hamiltonian() @@ -30,10 +30,10 @@ def test_tshs_si_pdos_kgrid(sisl_files, sisl_tmp): def test_tshs_si_pdos_kgrid_tofromnc(sisl_files, sisl_tmp): pytest.importorskip("netCDF4") - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSHS')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSHS")) HS1 = si.read_hamiltonian() - f = sisl_tmp('tmp.TSHS', _dir) - fnc = sisl_tmp('tmp.nc', _dir) + f = sisl_tmp("tmp.TSHS", _dir) + fnc = sisl_tmp("tmp.nc", _dir) HS1.write(f) HS1.write(fnc) @@ -50,7 +50,7 @@ def test_tshs_si_pdos_kgrid_tofromnc(sisl_files, sisl_tmp): def test_tshs_si_pdos_kgrid_repeat_tile(sisl_files, sisl_tmp): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSHS')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSHS")) HS = si.read_hamiltonian() HSr = HS.repeat(3, 2).repeat(3, 0).repeat(3, 1) HSt = HS.tile(3, 2).tile(3, 0).tile(3, 1) @@ -58,7 +58,7 @@ def test_tshs_si_pdos_kgrid_repeat_tile(sisl_files, sisl_tmp): def test_tshs_si_pdos_kgrid_repeat_tile_not_used(sisl_files, sisl_tmp): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSHS')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSHS")) HS = si.read_hamiltonian() for i in range(HS.no): HS._csr._extend_empty(i, 3 + i % 3) @@ -68,9 +68,9 @@ def test_tshs_si_pdos_kgrid_repeat_tile_not_used(sisl_files, sisl_tmp): def test_tshs_soc_pt2_xx(sisl_files, sisl_tmp): - fdf = sisl.get_sile(sisl_files(_dir, 'SOC_Pt2_xx.fdf'), base=sisl_files(_dir)) + fdf = sisl.get_sile(sisl_files(_dir, "SOC_Pt2_xx.fdf"), base=sisl_files(_dir)) HS1 = fdf.read_hamiltonian() - f = sisl_tmp('tmp.TSHS', _dir) + f = sisl_tmp("tmp.TSHS", _dir) HS1.write(f) si = sisl.get_sile(f) HS2 = si.read_hamiltonian() @@ -81,44 +81,44 @@ def test_tshs_soc_pt2_xx(sisl_files, sisl_tmp): def test_tshs_soc_pt2_xx_pdos(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'SOC_Pt2_xx.fdf'), base=sisl_files(_dir)) - sc = fdf.read_lattice(order='TSHS') + fdf = sisl.get_sile(sisl_files(_dir, "SOC_Pt2_xx.fdf"), base=sisl_files(_dir)) + sc = fdf.read_lattice(order="TSHS") HS = fdf.read_hamiltonian() assert np.allclose(sc.cell, HS.geometry.lattice.cell) HS.eigenstate().PDOS(np.linspace(-2, 2, 400)) def test_tshs_warn(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSHS')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSHS")) # check number of orbitals geom = si.read_geometry() geom._atoms = sisl.Atoms([sisl.Atom(i + 1) for i in range(geom.na)]) - with pytest.warns(sisl.SislWarning, match='number of orbitals'): + with pytest.warns(sisl.SislWarning, match="number of orbitals"): si.read_hamiltonian(geometry=geom) # check cell geom = si.read_geometry() - geom.lattice.cell[:, :] = 1. - with pytest.warns(sisl.SislWarning, match='lattice vectors'): + geom.lattice.cell[:, :] = 1.0 + with pytest.warns(sisl.SislWarning, match="lattice vectors"): si.read_hamiltonian(geometry=geom) # check atomic coordinates geom = si.read_geometry() - geom.xyz[0, :] += 10. - with pytest.warns(sisl.SislWarning, match='atomic coordinates'): + geom.xyz[0, :] += 10.0 + with pytest.warns(sisl.SislWarning, match="atomic coordinates"): si.read_hamiltonian(geometry=geom) # check supercell geom = si.read_geometry() geom.set_nsc([1, 1, 1]) - with pytest.warns(sisl.SislWarning, match='supercell'): + with pytest.warns(sisl.SislWarning, match="supercell"): si.read_hamiltonian(geometry=geom) def test_tshs_error(sisl_files): # reading with a wrong geometry - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSHS')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSHS")) # check number of orbitals geom = si.read_geometry() @@ -128,7 +128,7 @@ def test_tshs_error(sisl_files): def test_tshs_si_pdos_kgrid_overlap(sisl_files): - si = sisl.get_sile(sisl_files(_dir, 'si_pdos_kgrid.TSHS')) + si = sisl.get_sile(sisl_files(_dir, "si_pdos_kgrid.TSHS")) HS = si.read_hamiltonian() S = si.read_overlap() assert HS._csr.spsame(S._csr) @@ -139,13 +139,19 @@ def test_tshs_si_pdos_kgrid_overlap(sisl_files): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_tshs_spin_orbit(sisl_tmp): - H1 = sisl.Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin('SO')) - H1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) - - f1 = sisl_tmp('tmp1.TSHS', _dir) - f2 = sisl_tmp('tmp2.TSHS', _dir) + H1 = sisl.Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin("SO")) + H1.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) + + f1 = sisl_tmp("tmp1.TSHS", _dir) + f2 = sisl_tmp("tmp2.TSHS", _dir) H1.write(f1) H1.finalize() H2 = sisl.get_sile(f1).read_hamiltonian() @@ -160,25 +166,30 @@ def test_tshs_spin_orbit(sisl_tmp): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_tshs_spin_orbit_tshs2nc2tshs(sisl_tmp): pytest.importorskip("netCDF4") - H1 = sisl.Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin('SO')) - H1.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) - - fdf_file = sisl_tmp('RUN.fdf', _dir) - f1 = sisl_tmp('tmp1.TSHS', _dir) - f2 = sisl_tmp('tmp1.nc', _dir) + H1 = sisl.Hamiltonian(sisl.geom.graphene(), spin=sisl.Spin("SO")) + H1.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) + + fdf_file = sisl_tmp("RUN.fdf", _dir) + f1 = sisl_tmp("tmp1.TSHS", _dir) + f2 = sisl_tmp("tmp1.nc", _dir) H1.write(f1) H1.finalize() H2 = sisl.get_sile(f1).read_hamiltonian() H2.write(f2) H3 = sisl.get_sile(f2).read_hamiltonian() - open(fdf_file, 'w').writelines([ - "SystemLabel tmp1" - ]) + open(fdf_file, "w").writelines(["SystemLabel tmp1"]) fdf = sisl.get_sile(fdf_file) - assert np.allclose(fdf.read_lattice(order='nc').cell, - fdf.read_lattice(order='TSHS').cell) + assert np.allclose( + fdf.read_lattice(order="nc").cell, fdf.read_lattice(order="TSHS").cell + ) assert H1._csr.spsame(H2._csr) assert np.allclose(H1._csr._D, H2._csr._D) assert H1._csr.spsame(H3._csr) @@ -187,14 +198,14 @@ def test_tshs_spin_orbit_tshs2nc2tshs(sisl_tmp): def test_tshs_missing_diagonal(sisl_tmp): H1 = sisl.Hamiltonian(sisl.geom.graphene()) - H1.construct(([0.1, 1.44], [0., -2.7])) + H1.construct(([0.1, 1.44], [0.0, -2.7])) # remove diagonal component here del H1[0, 0] - f1 = sisl_tmp('tmp1.TSHS', _dir) + f1 = sisl_tmp("tmp1.TSHS", _dir) H1.write(f1) - f2 = sisl_tmp('tmp2.TSHS', _dir) + f2 = sisl_tmp("tmp2.TSHS", _dir) H2 = sisl.get_sile(f1).read_hamiltonian() H2.write(f2) H3 = sisl.get_sile(f2).read_hamiltonian() @@ -203,6 +214,6 @@ def test_tshs_missing_diagonal(sisl_tmp): assert not H1._csr.spsame(H2._csr) assert H2._csr.spsame(H3._csr) assert np.allclose(H2._csr._D, H3._csr._D) - H1[0, 0] = 0. + H1[0, 0] = 0.0 H1.finalize() assert H1._csr.spsame(H2._csr) diff --git a/src/sisl/io/siesta/tests/test_wfsx.py b/src/sisl/io/siesta/tests/test_wfsx.py index 24ae0f0ff9..0a71aee29b 100644 --- a/src/sisl/io/siesta/tests/test_wfsx.py +++ b/src/sisl/io/siesta/tests/test_wfsx.py @@ -9,12 +9,14 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_wfsx_read(sisl_files): - fdf = sisl.get_sile(sisl_files(_dir, 'bi2se3_3ql.fdf')) - wfsx = sisl.get_sile(sisl_files(_dir, 'bi2se3_3ql.bands.WFSX'), parent=fdf.read_geometry()) + fdf = sisl.get_sile(sisl_files(_dir, "bi2se3_3ql.fdf")) + wfsx = sisl.get_sile( + sisl_files(_dir, "bi2se3_3ql.bands.WFSX"), parent=fdf.read_geometry() + ) info = wfsx.read_info() sizes = wfsx.read_sizes() diff --git a/src/sisl/io/siesta/tests/test_xv.py b/src/sisl/io/siesta/tests/test_xv.py index 313107169f..b0c8f05d5b 100644 --- a/src/sisl/io/siesta/tests/test_xv.py +++ b/src/sisl/io/siesta/tests/test_xv.py @@ -10,12 +10,12 @@ from sisl.io.siesta.xv import * pytestmark = [pytest.mark.io, pytest.mark.siesta] -_dir = osp.join('sisl', 'io', 'siesta') +_dir = osp.join("sisl", "io", "siesta") def test_xv1(sisl_tmp, sisl_system): - f = sisl_tmp('gr.XV', _dir) - sisl_system.g.write(xvSileSiesta(f, 'w')) + f = sisl_tmp("gr.XV", _dir) + sisl_system.g.write(xvSileSiesta(f, "w")) g = xvSileSiesta(f).read_geometry() # Assert they are the same @@ -25,10 +25,10 @@ def test_xv1(sisl_tmp, sisl_system): def test_xv_reorder(sisl_tmp, sisl_system): - f = sisl_tmp('gr.XV', _dir) + f = sisl_tmp("gr.XV", _dir) g = sisl_system.g.copy() g.atoms[0] = Atom(1) - g.write(xvSileSiesta(f, 'w')) + g.write(xvSileSiesta(f, "w")) g2 = xvSileSiesta(f).read_geometry() # Assert they are the same @@ -38,11 +38,11 @@ def test_xv_reorder(sisl_tmp, sisl_system): def test_xv_velocity(sisl_tmp, sisl_system): - f = sisl_tmp('gr.XV', _dir) + f = sisl_tmp("gr.XV", _dir) g = sisl_system.g.copy() g.atoms[0] = Atom(1) v = np.random.rand(len(g), 3) - g.write(xvSileSiesta(f, 'w'), velocity=v) + g.write(xvSileSiesta(f, "w"), velocity=v) # Try to read in different ways g2 = xvSileSiesta(f).read_geometry() @@ -62,11 +62,11 @@ def test_xv_velocity(sisl_tmp, sisl_system): def test_xv_ghost(sisl_tmp): - f = sisl_tmp('ghost.XV', _dir) + f = sisl_tmp("ghost.XV", _dir) a1 = Atom(1) am1 = Atom(-1) - g = Geometry([[0., 0., i] for i in range(2)], [a1, am1], 2.) - g.write(xvSileSiesta(f, 'w')) + g = Geometry([[0.0, 0.0, i] for i in range(2)], [a1, am1], 2.0) + g.write(xvSileSiesta(f, "w")) g2 = xvSileSiesta(f).read_geometry() assert np.allclose(g.cell, g2.cell) diff --git a/src/sisl/io/siesta/transiesta_grid.py b/src/sisl/io/siesta/transiesta_grid.py index 6535c45069..50c17cd115 100644 --- a/src/sisl/io/siesta/transiesta_grid.py +++ b/src/sisl/io/siesta/transiesta_grid.py @@ -11,16 +11,16 @@ from .siesta_grid import gridncSileSiesta from .sile import SileCDFSiesta -__all__ = ['tsvncSileSiesta'] +__all__ = ["tsvncSileSiesta"] -_eV2Ry = unit_convert('eV', 'Ry') -_Ry2eV = 1. / _eV2Ry +_eV2Ry = unit_convert("eV", "Ry") +_Ry2eV = 1.0 / _eV2Ry @set_module("sisl.io.siesta") class tsvncSileSiesta(gridncSileSiesta): - """ TranSiesta potential input Grid file object + """TranSiesta potential input Grid file object This potential input file is mainly intended for the Hartree solution which complements N-electrode calculations in TranSiesta. @@ -31,15 +31,15 @@ class tsvncSileSiesta(gridncSileSiesta): """ def read_grid(self, *args, **kwargs): - """ Reads the TranSiesta potential input grid """ + """Reads the TranSiesta potential input grid""" lattice = self.read_lattice().swapaxes(0, 2) # Create the grid - na = len(self._dimension('a')) - nb = len(self._dimension('b')) - nc = len(self._dimension('c')) + na = len(self._dimension("a")) + nb = len(self._dimension("b")) + nc = len(self._dimension("c")) - v = self._variable('V') + v = self._variable("V") # Create the grid, Siesta uses periodic, always grid = Grid([nc, nb, na], bc=Grid.PERIODIC, lattice=lattice, dtype=v.dtype) @@ -51,30 +51,34 @@ def read_grid(self, *args, **kwargs): return grid.swapaxes(0, 2) def write_grid(self, grid): - """ Write the Poisson solution to the TSV.nc file """ + """Write the Poisson solution to the TSV.nc file""" sile_raise_write(self) self.write_lattice(grid.lattice) - self._crt_dim(self, 'one', 1) - self._crt_dim(self, 'a', grid.shape[0]) - self._crt_dim(self, 'b', grid.shape[1]) - self._crt_dim(self, 'c', grid.shape[2]) - - vmin = self._crt_var(self, 'Vmin', 'f8', ('one',)) - vmin.info = 'Minimum value in the Poisson solution (for TranSiesta interpolation)' - vmin.unit = 'Ry' - vmax = self._crt_var(self, 'Vmax', 'f8', ('one',)) - vmax.info = 'Maximum value in the Poisson solution (for TranSiesta interpolation)' - vmax.unit = 'Ry' - - v = self._crt_var(self, 'V', grid.dtype, ('c', 'b', 'a')) - v.info = 'Poisson solution with custom boundary conditions' - v.unit = 'Ry' + self._crt_dim(self, "one", 1) + self._crt_dim(self, "a", grid.shape[0]) + self._crt_dim(self, "b", grid.shape[1]) + self._crt_dim(self, "c", grid.shape[2]) + + vmin = self._crt_var(self, "Vmin", "f8", ("one",)) + vmin.info = ( + "Minimum value in the Poisson solution (for TranSiesta interpolation)" + ) + vmin.unit = "Ry" + vmax = self._crt_var(self, "Vmax", "f8", ("one",)) + vmax.info = ( + "Maximum value in the Poisson solution (for TranSiesta interpolation)" + ) + vmax.unit = "Ry" + + v = self._crt_var(self, "V", grid.dtype, ("c", "b", "a")) + v.info = "Poisson solution with custom boundary conditions" + v.unit = "Ry" vmin[:] = grid.grid.min() * _eV2Ry vmax[:] = grid.grid.max() * _eV2Ry v[:, :, :] = np.swapaxes(grid.grid, 0, 2) * _eV2Ry -add_sile('TSV.nc', tsvncSileSiesta) +add_sile("TSV.nc", tsvncSileSiesta) diff --git a/src/sisl/io/siesta/xv.py b/src/sisl/io/siesta/xv.py index 9937844413..bec19e73f6 100644 --- a/src/sisl/io/siesta/xv.py +++ b/src/sisl/io/siesta/xv.py @@ -10,19 +10,19 @@ from ..sile import SileError, add_sile, sile_fh_open, sile_raise_write from .sile import SileSiesta -__all__ = ['xvSileSiesta'] +__all__ = ["xvSileSiesta"] -Bohr2Ang = unit_convert('Bohr', 'Ang') +Bohr2Ang = unit_convert("Bohr", "Ang") @set_module("sisl.io.siesta") class xvSileSiesta(SileSiesta): - """ Geometry file """ + """Geometry file""" @sile_fh_open() - def write_geometry(self, geometry, fmt='.9f', velocity=None): - """ Writes the geometry to the contained file + def write_geometry(self, geometry, fmt=".9f", velocity=None): + """Writes the geometry to the contained file Parameters ---------- @@ -40,23 +40,25 @@ def write_geometry(self, geometry, fmt='.9f', velocity=None): if velocity is None: velocity = np.zeros([geometry.na, 3], np.float32) if geometry.xyz.shape != velocity.shape: - raise SileError(f'{self}.write_geometry requires the input' - 'velocity to have equal length to the input geometry.') + raise SileError( + f"{self}.write_geometry requires the input" + "velocity to have equal length to the input geometry." + ) # Write unit-cell tmp = np.zeros(6, np.float64) # Create format string for the cell-parameters - fmt_str = (' ' + ('{:' + fmt + '} ') * 3) * 2 + '\n' + fmt_str = (" " + ("{:" + fmt + "} ") * 3) * 2 + "\n" for i in range(3): tmp[0:3] = geometry.cell[i, :] / Bohr2Ang self._write(fmt_str.format(*tmp)) - self._write(f'{geometry.na:12d}\n') + self._write(f"{geometry.na:12d}\n") # Create format string for the atomic coordinates - fmt_str = '{:3d}{:6d} ' - fmt_str += ('{:' + fmt + '} ') * 3 + ' ' - fmt_str += ('{:' + fmt + '} ') * 3 + '\n' + fmt_str = "{:3d}{:6d} " + fmt_str += ("{:" + fmt + "} ") * 3 + " " + fmt_str += ("{:" + fmt + "} ") * 3 + "\n" for ia, a, ips in geometry.iter_species(): tmp[0:3] = geometry.xyz[ia, :] / Bohr2Ang tmp[3:] = velocity[ia, :] / Bohr2Ang @@ -67,7 +69,7 @@ def write_geometry(self, geometry, fmt='.9f', velocity=None): @sile_fh_open() def read_lattice(self): - """ Returns `Lattice` object from the XV file """ + """Returns `Lattice` object from the XV file""" cell = np.empty([3, 3], np.float64) for i in range(3): @@ -78,7 +80,7 @@ def read_lattice(self): @sile_fh_open() def read_geometry(self, velocity=False, species_Z=False): - """ Returns a `Geometry` object from the XV file + """Returns a `Geometry` object from the XV file Parameters ---------- @@ -134,11 +136,11 @@ def read_geometry(self, velocity=False, species_Z=False): @sile_fh_open() def read_velocity(self): - """ Returns an array with the velocities from the XV file + """Returns an array with the velocities from the XV file Returns ------- - velocity : + velocity : """ self.read_lattice() na = int(self.readline()) @@ -153,10 +155,10 @@ def read_velocity(self): read_data = read_velocity def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('XV', xvSileSiesta, gzip=True) +add_sile("XV", xvSileSiesta, gzip=True) diff --git a/src/sisl/io/sile.py b/src/sisl/io/sile.py index 3a7c46b1e0..de0a121e66 100644 --- a/src/sisl/io/sile.py +++ b/src/sisl/io/sile.py @@ -20,31 +20,21 @@ from ._help import * # Public used objects -__all__ = [ - 'add_sile', - 'get_sile_class', - 'get_sile', - 'get_siles', - 'get_sile_rules' -] +__all__ = ["add_sile", "get_sile_class", "get_sile", "get_siles", "get_sile_rules"] __all__ += [ - 'BaseSile', - 'BufferSile', - 'Sile', - 'SileCDF', - 'SileBin', - 'SileError', - 'SileWarning', - 'SileInfo', + "BaseSile", + "BufferSile", + "Sile", + "SileCDF", + "SileBin", + "SileError", + "SileWarning", + "SileInfo", ] # Decorators or sile-specific functions -__all__ += [ - 'sile_fh_open', - 'sile_raise_write', - 'sile_raise_read' -] +__all__ += ["sile_fh_open", "sile_raise_write", "sile_raise_read"] # Global container of all Sile rules # This list of tuples is formed as @@ -60,7 +50,7 @@ class _sile_rule: - """ Internal data-structure to check whether a file is the same as this sile """ + """Internal data-structure to check whether a file is the same as this sile""" COMPARISONS = { "contains": contains, @@ -69,7 +59,7 @@ class _sile_rule: "startswith": str.startswith, } - __slots__ = ('cls', 'case', 'suffix', 'gzip', 'bases', 'base_names') + __slots__ = ("cls", "case", "suffix", "gzip", "bases", "base_names") def __init__(self, cls, suffix, case=True, gzip=False): self.cls = cls @@ -83,18 +73,21 @@ def __init__(self, cls, suffix, case=True, gzip=False): self.base_names = [c.__name__.lower() for c in self.bases] def __str__(self): - s = '{cls}{{case={case}, suffix={suffix}, gzip={gzip},\n '.format(cls=self.cls.__name__, case=self.case, - suffix=self.suffix, gzip=self.gzip) + s = "{cls}{{case={case}, suffix={suffix}, gzip={gzip},\n ".format( + cls=self.cls.__name__, case=self.case, suffix=self.suffix, gzip=self.gzip + ) for b in self.bases: - s += f' {b.__name__},\n ' - return s[:-3] + '\n}' + s += f" {b.__name__},\n " + return s[:-3] + "\n}" def __repr__(self): - return (f"<{self.cls.__name__}, case={self.case}, " - f"suffix={self.suffix}, gzip={self.gzip}>") + return ( + f"<{self.cls.__name__}, case={self.case}, " + f"suffix={self.suffix}, gzip={self.gzip}>" + ) def build_bases(self): - """ Return a list of all classes that this file is inheriting from (except Sile, SileBin or SileCDF) """ + """Return a list of all classes that this file is inheriting from (except Sile, SileBin or SileCDF)""" children = list(self.cls.__bases__) + [self.cls] nl = -1 while len(children) != nl: @@ -106,7 +99,7 @@ def build_bases(self): children.pop(i) except Exception: pass - for child in list(children): # ensure we have a copy for infinite loops + for child in list(children): # ensure we have a copy for infinite loops for c in child.__bases__: if c not in children: children.append(c) @@ -114,7 +107,7 @@ def build_bases(self): return children def in_bases(self, base, method="contains"): - """ Whether any of the inherited bases compares with `base` in their class-name (lower-case sensitive) """ + """Whether any of the inherited bases compares with `base` in their class-name (lower-case sensitive)""" if base is None: return True elif isinstance(base, object): @@ -126,7 +119,7 @@ def in_bases(self, base, method="contains"): return False def get_base(self, base, method="contains"): - """ Whether any of the inherited bases compares with `base` in their class-name (lower-case sensitive) """ + """Whether any of the inherited bases compares with `base` in their class-name (lower-case sensitive)""" if base is None: return None comparison = self.COMPARISONS[method] @@ -136,7 +129,7 @@ def get_base(self, base, method="contains"): return None def in_class(self, base, method="contains"): - """ Whether any of the inherited bases compares with `base` in their class-name (lower-case sensitive) """ + """Whether any of the inherited bases compares with `base` in their class-name (lower-case sensitive)""" if base is None: return False comparison = self.COMPARISONS[method] @@ -166,7 +159,7 @@ def is_subclass(self, cls): @set_module("sisl.io") def add_sile(suffix, cls, case=True, gzip=False): - """ Add files to the global lookup table + """Add files to the global lookup table Public for attaching lookup tables for allowing users to attach files externally. @@ -191,7 +184,7 @@ def add_sile(suffix, cls, case=True, gzip=False): raise ValueError(f"Class {cls.__name__} must be a subclass of BaseSile!") # Only add pure suffixes... - if suffix.startswith('.'): + if suffix.startswith("."): suffix = suffix[1:] # If it isn't already in the list of @@ -205,7 +198,7 @@ def add_sile(suffix, cls, case=True, gzip=False): @set_module("sisl.io") def get_sile_class(filename, *args, **kwargs): - """ Retrieve a class from the global lookup table via filename and the extension + """Retrieve a class from the global lookup table via filename and the extension Parameters ---------- @@ -227,7 +220,7 @@ def get_sile_class(filename, *args, **kwargs): global __sile_rules, __siles # This ensures that the first argument need not be cls - cls = kwargs.pop('cls', None) + cls = kwargs.pop("cls", None) # Split filename into proper file name and # the Specification of the type @@ -244,14 +237,15 @@ def get_sile_class(filename, *args, **kwargs): if "=" in specification: method, cls_search = specification.split("=", 1) if "=" in cls_search: - raise ValueError(f"Comparison specification currently only supports one level of comparison(single =); got {specification}") + raise ValueError( + f"Comparison specification currently only supports one level of comparison(single =); got {specification}" + ) else: method, cls_search = "contains", specification # searchable rules eligible_rules = [] if cls is None and not cls_search is None: - # cls has not been set, and fcls is found # Figure out if fcls is a valid sile, if not # do nothing (it may be part of the file name) @@ -265,7 +259,9 @@ def get_sile_class(filename, *args, **kwargs): # we have at least one eligible rule filename = tmp_file else: - warn(f"Specification requirement of the file did not result in any found files: {specification}") + warn( + f"Specification requirement of the file did not result in any found files: {specification}" + ) else: # search everything @@ -282,10 +278,11 @@ def get_sile_class(filename, *args, **kwargs): # Create list of endings on this file f = basename(filename) end_list = [] - end = '' + end = "" def try_methods(eligibles, prefixes=("read_",)): - """ return only those who can actually perform the read actions """ + """return only those who can actually perform the read actions""" + def has(keys): nonlocal prefixes has_keys = [] @@ -317,7 +314,7 @@ def has(keys): lext = splitext(f) while len(lext[1]) > 0: end = lext[1] + end - if end[0] == '.': + if end[0] == ".": end_list.append(end[1:]) else: end_list.append(end) @@ -346,18 +343,22 @@ def get_eligibles(end, rules): # First we check for class AND file ending for end, rules in product(end_list, (eligible_rules, __sile_rules)): - eligibles = get_eligibles(end, rules) - # Determine whether we have found a compatible sile - if len(eligibles) == 1: + eligibles = get_eligibles(end, rules) + # Determine whether we have found a compatible sile + if len(eligibles) == 1: return eligibles[0].cls - elif len(eligibles) > 1: + elif len(eligibles) > 1: workable_eligibles = try_methods(eligibles) if len(workable_eligibles) == 1: return workable_eligibles[0].cls - raise ValueError(f"Cannot determine the exact Sile requested, multiple hits: {tuple(e.cls.__name__ for e in eligibles)}") + raise ValueError( + f"Cannot determine the exact Sile requested, multiple hits: {tuple(e.cls.__name__ for e in eligibles)}" + ) - raise NotImplementedError(f"Sile for file '{filename}' could not be found, " - "possibly the file has not been implemented.") + raise NotImplementedError( + f"Sile for file '{filename}' could not be found, " + "possibly the file has not been implemented." + ) except Exception as e: raise e @@ -365,7 +366,7 @@ def get_eligibles(end, rules): @set_module("sisl.io") def get_sile(file, *args, **kwargs): - """ Retrieve an object from the global lookup table via filename and the extension + """Retrieve an object from the global lookup table via filename and the extension Internally this is roughly equivalent to ``get_sile_class(...)()``. @@ -407,14 +408,14 @@ class to construct the file object, see examples. >>> cls = get_sile_class("water.dat{startswith=xyz}") """ - cls = kwargs.pop('cls', None) + cls = kwargs.pop("cls", None) sile = get_sile_class(file, *args, cls=cls, **kwargs) return sile(Path(str_spec(str(file))[0]), *args, **kwargs) @set_module("sisl.io") def get_siles(attrs=None): - """ Retrieve all files with specific attributes or methods + """Retrieve all files with specific attributes or methods Parameters ---------- @@ -478,29 +479,29 @@ def get_sile_rules(attrs=None, cls=None): @set_module("sisl.io") class BaseSile: - """ Base class for all sisl files """ + """Base class for all sisl files""" def __init__(self, *args, **kwargs): - """ Just to pass away the args and kwargs """ + """Just to pass away the args and kwargs""" @property def file(self): - """ File of the current `Sile` """ + """File of the current `Sile`""" return self._file @property def base_file(self): - """ File of the current `Sile` """ + """File of the current `Sile`""" return basename(self._file) def dir_file(self, filename=None, filename_base=""): - """ File of the current `Sile` """ + """File of the current `Sile`""" if filename is None: filename = Path(self._file).name return self._directory / filename_base / filename def read(self, *args, **kwargs): - """ Generic read method which should be overloaded in child-classes + """Generic read method which should be overloaded in child-classes Parameters ---------- @@ -523,7 +524,7 @@ def read(self, *args, **kwargs): _write_default_only = False def write(self, *args, **kwargs): - """ Generic write method which should be overloaded in child-classes + """Generic write method which should be overloaded in child-classes Parameters ---------- @@ -544,7 +545,7 @@ def write(self, *args, **kwargs): func(kwargs[key], **kwargs) def _setup(self, *args, **kwargs): - """ Setup the `Sile` after initialization + """Setup the `Sile` after initialization Inherited method for setting up the sile. @@ -553,59 +554,69 @@ def _setup(self, *args, **kwargs): pass def _base_setup(self, *args, **kwargs): - """ Setup the `Sile` after initialization + """Setup the `Sile` after initialization Inherited method for setting up the sile. This method **must** be overwritten *and* end with ``self._setup()``. """ - base = kwargs.get('base', None) + base = kwargs.get("base", None) if base is None: # Extract from filename self._directory = Path(self._file).parent else: self._directory = base if not str(self._directory): - self._directory = '.' + self._directory = "." self._directory = Path(self._directory).resolve() self._setup(*args, **kwargs) def _base_file(self, f): - """ Make `f` refer to the file with the appropriate base directory """ + """Make `f` refer to the file with the appropriate base directory""" return self._directory / f def __getattr__(self, name): - """ Override to check the handle """ - if name == 'fh': - raise AttributeError(f"The filehandle for {self.file} has not been opened yet...") + """Override to check the handle""" + if name == "fh": + raise AttributeError( + f"The filehandle for {self.file} has not been opened yet..." + ) if name == "read_supercell" and hasattr(self, "read_lattice"): - deprecate(f"{self.__class__.__name__}.read_supercell is deprecated in favor of read_lattice", "0.15") + deprecate( + f"{self.__class__.__name__}.read_supercell is deprecated in favor of read_lattice", + "0.15", + ) return getattr(self, "read_lattice") if name == "write_supercell" and hasattr(self, "write_lattice"): - deprecate(f"{self.__class__.__name__}.write_supercell is deprecated in favor of write_lattice", "0.15") + deprecate( + f"{self.__class__.__name__}.write_supercell is deprecated in favor of write_lattice", + "0.15", + ) return getattr(self, "write_lattice") return getattr(self.fh, name) @classmethod def _ArgumentParser_args_single(cls): - """ Default arguments for the Sile """ + """Default arguments for the Sile""" return {} # Define the custom ArgumentParser def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that may be available for this Sile + """Returns the arguments that may be available for this Sile Parameters ---------- p : ArgumentParser the argument parser to add the arguments to. """ - raise NotImplementedError(f"The ArgumentParser of '{self.__class__.__name__}' has not been implemented yet.") + raise NotImplementedError( + f"The ArgumentParser of '{self.__class__.__name__}' has not been implemented yet." + ) def ArgumentParser_out(self, p=None, *args, **kwargs): - """ Appends additional arguments based on the output of the file + """Appends additional arguments based on the output of the file Parameters ---------- @@ -615,19 +626,20 @@ def ArgumentParser_out(self, p=None, *args, **kwargs): pass def __str__(self): - """ Return a representation of the `Sile` """ + """Return a representation of the `Sile`""" # Check if the directory is relative to the current path # If so, only print the relative path, otherwise print the full path d = self._directory try: # bypass d.is_relative_to, added in 3.9 - d = d.relative_to(Path('.').resolve()) - except Exception: pass + d = d.relative_to(Path(".").resolve()) + except Exception: + pass return f"{self.__class__.__name__}({self.base_file!s}, base={d!s})" def sile_fh_open(from_closed=False, reset=None): - """ Method decorator for objects to directly implement opening of the + """Method decorator for objects to directly implement opening of the file-handle upon entry (if it isn't already). Parameters @@ -639,26 +651,37 @@ def sile_fh_open(from_closed=False, reset=None): ``reset(self)`` """ if reset is None: - def reset(self): pass + + def reset(self): + pass if from_closed: + def _wrapper(func): nonlocal reset + @wraps(func) def pre_open(self, *args, **kwargs): # only call reset if the file should be reset _reset = reset if hasattr(self, "fh"): - def _reset(self): pass + + def _reset(self): + pass + with self: # REMARK this requires the __enter__ to seek(0) # for the file, and currently it does _reset(self) return func(self, *args, **kwargs) + return pre_open + else: + def _wrapper(func): nonlocal reset + @wraps(func) def pre_open(self, *args, **kwargs): if hasattr(self, "fh"): @@ -666,13 +689,15 @@ def pre_open(self, *args, **kwargs): with self: reset(self) return func(self, *args, **kwargs) + return pre_open + return _wrapper @set_module("sisl.io") class BufferSile: - """ Sile for handling `StringIO` and `TextIOBase` objects + """Sile for handling `StringIO` and `TextIOBase` objects These are basically meant for users passing down the above objects """ @@ -713,13 +738,13 @@ def __exit__(self, type, value, traceback): return False def close(self): - """ Will not close the file since this is passed by the user """ + """Will not close the file since this is passed by the user""" pass @set_module("sisl.io") class Info: - """ An info class that creates .info with inherent properties + """An info class that creates .info with inherent properties These properties can be added at will. """ @@ -732,7 +757,7 @@ def __init__(self, *args, **kwargs): self.info = _Info(self) class _Info: - """ The actual .info object that will attached to the instance. + """The actual .info object that will attached to the instance. As of now this is problematic to document. We should figure out a way to do that. @@ -759,6 +784,7 @@ def readline(*args, **kwargs): for prop in properties: prop.process(line) return line + return readline self._instance.readline = patch(self) @@ -772,19 +798,21 @@ def readline(*args, **kwargs): self.add_property(prop) def add_property(self, prop): - """ Add a new property to be reachable from the .info """ + """Add a new property to be reachable from the .info""" self._attrs.append(prop.attr) self._properties.append(prop) def __str__(self): - """ Return a string of the contained attributes, with the values they currently contain """ + """Return a string of the contained attributes, with the values they currently contain""" return "\n".join([p.documentation() for p in self._properties]) def __getattr__(self, attr): - """ Overwrite the attribute retrieval to be able to fetch the actual values from the information """ + """Overwrite the attribute retrieval to be able to fetch the actual values from the information""" inst = self._instance if attr not in self._attrs: - raise AttributeError(f"{inst.__class__.__name__}.info.{attr} does not exist, did you mistype?") + raise AttributeError( + f"{inst.__class__.__name__}.info.{attr} does not exist, did you mistype?" + ) idx = self._attrs.index(attr) prop = self._properties[idx] @@ -801,7 +829,7 @@ def __getattr__(self, attr): pass with inst: line = inst.readline() - while not (prop.found or line == ''): + while not (prop.found or line == ""): line = inst.readline() if loc is not None: inst.fh.seek(loc) @@ -814,7 +842,7 @@ def __getattr__(self, attr): return prop.value class InfoAttr: - """ Holder for parsing lines and extracting information from text files + """Holder for parsing lines and extracting information from text files This consists of: @@ -842,17 +870,19 @@ def parser(attr, match) found: whether the value has been found in the file. """ + __slots__ = ("attr", "regex", "parser", "updatable", "value", "found", "doc") - def __init__(self, - attr: str, - regex: Union[str, re.Pattern], - parser, - doc: str="", - updatable: bool=False, - default: Optional[Any]=None, - found: bool=False, - ): + def __init__( + self, + attr: str, + regex: Union[str, re.Pattern], + parser, + doc: str = "", + updatable: bool = False, + default: Optional[Any] = None, + found: bool = False, + ): self.attr = attr if isinstance(regex, str): regex = re.compile(regex) @@ -870,23 +900,25 @@ def process(self, line): match = self.regex.match(line) if match: self.value = self.parser(self, match) - #print(f"found {self.attr}={self.value} with {line}") + # print(f"found {self.attr}={self.value} with {line}") self.found = True return True return False def copy(self): - return self.__class__(attr=self.attr, - regex=self.regex, - parser=self.parser, - doc=self.doc, - updatable=self.updatable, - default=self.value, - found=self.found) + return self.__class__( + attr=self.attr, + regex=self.regex, + parser=self.parser, + doc=self.doc, + updatable=self.updatable, + default=self.value, + found=self.found, + ) def documentation(self): - """ Returns a documentation string for this object """ + """Returns a documentation string for this object""" if self.doc: doc = "\n" + indent(dedent(self.doc), " " * 4) else: @@ -900,7 +932,7 @@ def __init__(self, *args, **kwargs): @set_module("sisl.io") class Sile(Info, BaseSile): - """ Base class for ASCII files + """Base class for ASCII files All ASCII files that needs to be added to the global lookup table can with benefit inherit this class. @@ -913,7 +945,6 @@ class Sile(Info, BaseSile): """ def __new__(cls, filename, *args, **kwargs): - # check whether filename is an actual str, or StringIO or some buffer if not isinstance(filename, TextIOBase): # this is just a regular sile opening @@ -927,16 +958,21 @@ def __init_subclass__(cls, buffer_cls=None): return if buffer_cls is None: - buffer_cls = type(f"{cls.__name__}Buffer", (BufferSile, cls), - # Ensure the module is the same - {"__module__": cls.__module__}) + buffer_cls = type( + f"{cls.__name__}Buffer", + (BufferSile, cls), + # Ensure the module is the same + {"__module__": cls.__module__}, + ) elif not issubclass(buffer_cls, BufferSile): - raise TypeError(f"The passed buffer_cls should inherit from sisl.io.BufferSile to " - "ensure correct behaviour.") + raise TypeError( + f"The passed buffer_cls should inherit from sisl.io.BufferSile to " + "ensure correct behaviour." + ) cls._buffer_cls = buffer_cls - def __init__(self, filename, mode='r', *args, **kwargs): + def __init__(self, filename, mode="r", *args, **kwargs): super().__init__(*args, **kwargs) self._file = Path(filename) self._mode = mode @@ -963,9 +999,9 @@ def _open(self): self._fh_opens = 0 if self.file.suffix == ".gz": - if self._mode == 'r': + if self._mode == "r": # assume the file is a text file and open in text-mode - self.fh = gzip.open(str(self.file), mode='rt') + self.fh = gzip.open(str(self.file), mode="rt") else: # assume this is opening in binary or write mode self.fh = gzip.open(str(self.file), mode=self._mode) @@ -978,7 +1014,7 @@ def _open(self): self._fh_opens += 1 def __enter__(self): - """ Opens the output file and returns it self """ + """Opens the output file and returns it self""" self._open() return self @@ -993,24 +1029,24 @@ def close(self): if self._fh_opens <= 0: self._line = 0 self.fh.close() - delattr(self, 'fh') + delattr(self, "fh") self._fh_opens = 0 @staticmethod def is_keys(keys): - """ Returns true if ``not isinstance(keys, str)`` """ + """Returns true if ``not isinstance(keys, str)``""" return not isinstance(keys, str) @staticmethod def key2case(key, case): - """ Converts str/list of keywords to proper case """ + """Converts str/list of keywords to proper case""" if case: return key return key.lower() @staticmethod def keys2case(keys, case): - """ Converts str/list of keywords to proper case """ + """Converts str/list of keywords to proper case""" if case: return keys return [k.lower() for k in keys] @@ -1034,14 +1070,14 @@ def line_has_keys(line, keys, case=True): return found def __iter__(self): - """ Reading the entire content, without regarding comments """ + """Reading the entire content, without regarding comments""" l = self.readline(comment=True) while l: yield l l = self.readline(comment=True) def readline(self, comment=False): - r""" Reads the next line of the file """ + r"""Reads the next line of the file""" l = self.fh.readline() self._line += 1 if comment: @@ -1051,8 +1087,10 @@ def readline(self, comment=False): self._line += 1 return l - def step_to(self, keywords, case=True, allow_reread=True, ret_index=False, reopen=False): - r""" Steps the file-handle until the keyword(s) is found in the input + def step_to( + self, keywords, case=True, allow_reread=True, ret_index=False, reopen=False + ): + r"""Steps the file-handle until the keyword(s) is found in the input Parameters ---------- @@ -1095,11 +1133,11 @@ def step_to(self, keywords, case=True, allow_reread=True, ret_index=False, reope while not found: l = self.readline() - if l == '': + if l == "": break found = self.line_has_keys(l, keys, case) - if not found and (l == '' and line > 0) and allow_reread: + if not found and (l == "" and line > 0) and allow_reread: # We may be in the case where the user request # reading the same twice... # So we need to re-read the file... @@ -1110,7 +1148,7 @@ def step_to(self, keywords, case=True, allow_reread=True, ret_index=False, reope # Try and read again while not found and self._line <= line: l = self.readline() - if l == '': + if l == "": break found = self.line_has_keys(l, keys, case) @@ -1130,7 +1168,7 @@ def step_to(self, keywords, case=True, allow_reread=True, ret_index=False, reope return found, l def _write(self, *args, **kwargs): - """ Wrapper to default the write statements """ + """Wrapper to default the write statements""" self.fh.write(*args, **kwargs) @@ -1139,6 +1177,7 @@ def _write(self, *args, **kwargs): # a pass around it netCDF4 = None + def _import_netCDF4(): global netCDF4 if netCDF4 is None: @@ -1147,6 +1186,7 @@ def _import_netCDF4(): except ImportError as e: # append import sys + exe = Path(sys.executable).name msg = f"Could not import netCDF4. Please install it using '{exe} -m pip install netCDF4'" raise SileError(msg) from e @@ -1154,7 +1194,7 @@ def _import_netCDF4(): @set_module("sisl.io") class SileCDF(BaseSile): - """ Creates/Opens a SileCDF + """Creates/Opens a SileCDF Opens a SileCDF with `mode` and compression level `lvl`. If `mode` is in read-mode (r) the compression level @@ -1167,7 +1207,7 @@ class SileCDF(BaseSile): 1) means stores certain variables in the object. """ - def __init__(self, filename, mode='r', lvl=0, access=1, *args, **kwargs): + def __init__(self, filename, mode="r", lvl=0, access=1, *args, **kwargs): _import_netCDF4() self._file = Path(filename) @@ -1185,26 +1225,29 @@ def __init__(self, filename, mode='r', lvl=0, access=1, *args, **kwargs): self._access = 0 # The CDF file can easily open the file - if kwargs.pop('_open', True): - self.__dict__['fh'] = netCDF4.Dataset(str(self.file), self._mode, - format='NETCDF4') + if kwargs.pop("_open", True): + self.__dict__["fh"] = netCDF4.Dataset( + str(self.file), self._mode, format="NETCDF4" + ) # Must call setup-methods self._base_setup(*args, **kwargs) @property def _cmp_args(self): - """ Returns the compression arguments for the NetCDF file + """Returns the compression arguments for the NetCDF file >>> nc.createVariable(..., **self._cmp_args) """ - return {'zlib': self._lvl > 0, 'complevel': self._lvl} + return {"zlib": self._lvl > 0, "complevel": self._lvl} def __enter__(self): - """ Opens the output file and returns it self """ + """Opens the output file and returns it self""" # We do the import here - if 'fh' not in self.__dict__: - self.__dict__['fh'] = netCDF4.Dataset(str(self.file), self._mode, format='NETCDF4') + if "fh" not in self.__dict__: + self.__dict__["fh"] = netCDF4.Dataset( + str(self.file), self._mode, format="NETCDF4" + ) return self def __exit__(self, type, value, traceback): @@ -1215,12 +1258,12 @@ def close(self): self.fh.close() def _dimension(self, name, tree=None): - """ Local method for obtaing the dimension in a certain tree """ + """Local method for obtaing the dimension in a certain tree""" return self._dimensions(self, name, tree) @staticmethod def _dimensions(n, name, tree=None): - """ Retrieve method to get the NetCDF variable """ + """Retrieve method to get the NetCDF variable""" if tree is None: return n.dimensions[name] @@ -1234,7 +1277,7 @@ def _dimensions(n, name, tree=None): return g.dimensions[name] def _variable(self, name, tree=None): - """ Local method for obtaining the data from the SileCDF. + """Local method for obtaining the data from the SileCDF. This method returns the variable as-is. """ @@ -1244,7 +1287,7 @@ def _variable(self, name, tree=None): return self._variables(self, name, tree=tree) def _value(self, name, tree=None): - """ Local method for obtaining the data from the SileCDF. + """Local method for obtaining the data from the SileCDF. This method returns the value of the variable. """ @@ -1252,7 +1295,7 @@ def _value(self, name, tree=None): @staticmethod def _variables(n, name, tree=None): - """ Retrieve method to get the NetCDF variable """ + """Retrieve method to get the NetCDF variable""" if tree is None: return n.variables[name] @@ -1267,8 +1310,8 @@ def _variables(n, name, tree=None): @staticmethod def _crt_grp(n, name): - if '/' in name: # this is NetCDF, so / is fixed as seperator! - groups = name.split('/') + if "/" in name: # this is NetCDF, so / is fixed as seperator! + groups = name.split("/") grp = n for group in groups: if len(group) > 0: @@ -1290,8 +1333,8 @@ def _crt_var(n, name, *args, **kwargs): if name in n.variables: return n.variables[name] - if 'attrs' in kwargs: - attrs = kwargs.pop('attrs') + if "attrs" in kwargs: + attrs = kwargs.pop("attrs") else: attrs = None var = n.createVariable(name, *args, **kwargs) @@ -1302,7 +1345,7 @@ def _crt_var(n, name, *args, **kwargs): @classmethod def isDimension(cls, obj): - """ Return true if ``obj`` is an instance of the NetCDF4 ``Dimension`` type + """Return true if ``obj`` is an instance of the NetCDF4 ``Dimension`` type This is just a wrapper for ``isinstance(obj, netCDF4.Dimension)``. """ @@ -1310,7 +1353,7 @@ def isDimension(cls, obj): @classmethod def isVariable(cls, obj): - """ Return true if ``obj`` is an instance of the NetCDF4 ``Variable`` type + """Return true if ``obj`` is an instance of the NetCDF4 ``Variable`` type This is just a wrapper for ``isinstance(obj, netCDF4.Variable)``. """ @@ -1318,7 +1361,7 @@ def isVariable(cls, obj): @classmethod def isGroup(cls, obj): - """ Return true if ``obj`` is an instance of the NetCDF4 ``Group`` type + """Return true if ``obj`` is an instance of the NetCDF4 ``Group`` type This is just a wrapper for ``isinstance(obj, netCDF4.Group)``. """ @@ -1326,15 +1369,16 @@ def isGroup(cls, obj): @classmethod def isDataset(cls, obj): - """ Return true if ``obj`` is an instance of the NetCDF4 ``Dataset`` type + """Return true if ``obj`` is an instance of the NetCDF4 ``Dataset`` type This is just a wrapper for ``isinstance(obj, netCDF4.Dataset)``. """ return isinstance(obj, netCDF4.Dataset) + isRoot = isDataset def iter(self, group=True, dimension=True, variable=True, levels=-1, root=None): - """ Iterator on all groups, variables and dimensions. + """Iterator on all groups, variables and dimensions. This iterator iterates through all groups, variables and dimensions in the ``Dataset`` @@ -1395,57 +1439,66 @@ def iter(self, group=True, dimension=True, variable=True, levels=-1, root=None): return for grp in head.groups.values(): - yield from self.iter(group, dimension, variable, - levels=levels-1, root=grp.path) + yield from self.iter( + group, dimension, variable, levels=levels - 1, root=grp.path + ) __iter__ = iter @set_module("sisl.io") class SileBin(BaseSile): - """ Creates/Opens a SileBin + """Creates/Opens a SileBin Opens a SileBin with `mode` (b). If `mode` is in read-mode (r). """ - def __init__(self, filename, mode='r', *args, **kwargs): + def __init__(self, filename, mode="r", *args, **kwargs): self._file = Path(filename) # Open mode - self._mode = mode.replace('b', '') + 'b' + self._mode = mode.replace("b", "") + "b" # Must call setup-methods self._base_setup(*args, **kwargs) def __enter__(self): - """ Opens the output file and returns it self """ + """Opens the output file and returns it self""" return self def __exit__(self, type, value, traceback): return False -def sile_raise_write(self, ok=('w', 'a')): +def sile_raise_write(self, ok=("w", "a")): is_ok = False for O in ok: is_ok = is_ok or (O in self._mode) if not is_ok: - raise SileError(('Writing to file not possible; allowed ' - 'modes={}, used mode={}'.format(ok, self._mode)), self) + raise SileError( + ( + "Writing to file not possible; allowed " + "modes={}, used mode={}".format(ok, self._mode) + ), + self, + ) -def sile_raise_read(self, ok=('r', 'a')): +def sile_raise_read(self, ok=("r", "a")): is_ok = False for O in ok: is_ok = is_ok or (O in self._mode) if not is_ok: - raise SileError(f"Reading file not possible; allowed " - f"modes={ok}, used mode={self._mode}", self) + raise SileError( + f"Reading file not possible; allowed " + f"modes={ok}, used mode={self._mode}", + self, + ) @set_module("sisl.io") class SileError(IOError): - """ Define an error object related to the Sile objects """ + """Define an error object related to the Sile objects""" def __init__(self, value, obj=None): self.value = value @@ -1460,15 +1513,17 @@ def __str__(self): @set_module("sisl.io") class SileWarning(SislWarning): - """ Warnings that informs users of things to be carefull about when using their retrieved data + """Warnings that informs users of things to be carefull about when using their retrieved data These warnings should be issued whenever a read/write routine is unable to retrieve all information but are non-influential in the sense that sisl is still able to perform the action. """ + pass @set_module("sisl.io") class SileInfo(SislInfo): - """ Information for the user, this is hidden in a warning, but is not as severe so as to issue a warning. """ + """Information for the user, this is hidden in a warning, but is not as severe so as to issue a warning.""" + pass diff --git a/src/sisl/io/table.py b/src/sisl/io/table.py index 8cb66a51ce..74c39d88d5 100644 --- a/src/sisl/io/table.py +++ b/src/sisl/io/table.py @@ -10,12 +10,12 @@ from .sile import Sile, add_sile, sile_fh_open, sile_raise_write -__all__ = ['tableSile', 'TableSile'] +__all__ = ["tableSile", "TableSile"] @set_module("sisl.io") class tableSile(Sile): - """ ASCII tabular formatted data + """ASCII tabular formatted data A generic table data which will easily accommodate the most common write-outs of data @@ -71,13 +71,13 @@ class tableSile(Sile): """ def _setup(self, *args, **kwargs): - """ Setup the `tableSile` after initialization """ + """Setup the `tableSile` after initialization""" super()._setup(*args, **kwargs) - self._comment = ['#'] + self._comment = ["#"] @sile_fh_open() def write_data(self, *args, **kwargs): - """ Write tabular data to the file with optional header. + """Write tabular data to the file with optional header. Parameters ---------- @@ -120,31 +120,36 @@ def write_data(self, *args, **kwargs): """ sile_raise_write(self) - fmt = kwargs.get('fmt', '.5e') - newline = kwargs.get('newline', '\n') - delimiter = kwargs.get('delimiter', '\t') + fmt = kwargs.get("fmt", ".5e") + newline = kwargs.get("newline", "\n") + delimiter = kwargs.get("delimiter", "\t") _com = self._comment[0] - def comment_newline(line, prefix=''): - """ Converts a list of str arguments into nicely formatted commented - and newlined output """ + def comment_newline(line, prefix=""): + """Converts a list of str arguments into nicely formatted commented + and newlined output""" nonlocal _com line = map(lambda s: s.strip(), line.strip().split(newline)) # always append a newline - line = newline.join([s if s.startswith(_com) else f"{_com}{prefix}{s}" for s in line]) + newline + line = ( + newline.join( + [s if s.startswith(_com) else f"{_com}{prefix}{s}" for s in line] + ) + + newline + ) return line - comment = kwargs.get('comment', None) + comment = kwargs.get("comment", None) if comment is None: - comment = '' + comment = "" elif isinstance(comment, str): - comment = comment_newline(comment, ' ') + comment = comment_newline(comment, " ") else: - comment = comment_newline(newline.join(comment), ' ') + comment = comment_newline(newline.join(comment), " ") - header = kwargs.get('header', None) + header = kwargs.get("header", None) if header is None: - header = '' + header = "" elif isinstance(header, str): header = comment_newline(header) else: @@ -153,9 +158,9 @@ def comment_newline(line, prefix=''): # Finalize output header = comment + header - footer = kwargs.get('footer', None) + footer = kwargs.get("footer", None) if footer is None: - footer = '' + footer = "" elif isinstance(footer, str): pass else: @@ -182,7 +187,7 @@ def comment_newline(line, prefix=''): else: dat = np.vstack(args) - _fmt = '{:' + fmt + '}' + _fmt = "{:" + fmt + "}" # Reshape such that it becomes easy ndim = dat.ndim @@ -192,13 +197,17 @@ def comment_newline(line, prefix=''): dat.shape = (1, -1) if ndim > 2: - _fmt = kwargs.get('fmts', (_fmt + delimiter) * (dat.shape[1] - 1) + _fmt + newline) + _fmt = kwargs.get( + "fmts", (_fmt + delimiter) * (dat.shape[1] - 1) + _fmt + newline + ) for i in range(dat.shape[0]): for j in range(dat.shape[2]): self._write(_fmt.format(*dat[i, :, j])) self._write(newline * 2) else: - _fmt = kwargs.get('fmts', (_fmt + delimiter) * (dat.shape[0] - 1) + _fmt + newline) + _fmt = kwargs.get( + "fmts", (_fmt + delimiter) * (dat.shape[0] - 1) + _fmt + newline + ) for i in range(dat.shape[1]): self._write(_fmt.format(*dat[:, i])) @@ -207,7 +216,7 @@ def comment_newline(line, prefix=''): @sile_fh_open() def read_data(self, *args, **kwargs): - """ Read tabular data from the file. + """Read tabular data from the file. Parameters ---------- @@ -223,15 +232,15 @@ def read_data(self, *args, **kwargs): lines starting with this are discarded as comments """ # Override the comment in the file - self._comment = [kwargs.get('comment', self._comment[0])] + self._comment = [kwargs.get("comment", self._comment[0])] # Skip to next line comment = [] - header = '' + header = "" # Also read comments line = self.readline(True) - while line.startswith(self._comment[0] + ' '): + while line.startswith(self._comment[0] + " "): comment.append(line) line = self.readline(True) @@ -244,18 +253,20 @@ def read_data(self, *args, **kwargs): # First we need to figure out the separator: len_sep = 0 - sep = kwargs.get('delimiter', '') + sep = kwargs.get("delimiter", "") if len(sep) == 0: - for cur_sep in ['\t', ' ', ',']: + for cur_sep in ["\t", " ", ","]: s = line.split(cur_sep) if len(s) > len_sep: len_sep = len(s) sep = cur_sep if len(sep) == 0: - raise ValueError(self.__class__.__name__ + '.read_data could not determine ' - 'column separator...') + raise ValueError( + self.__class__.__name__ + ".read_data could not determine " + "column separator..." + ) - empty = re.compile(r'\s*\n') + empty = re.compile(r"\s*\n") while len(line) > 0: # If we start a line by a comment, or a newline # then we have a new data set @@ -288,8 +299,8 @@ def read_data(self, *args, **kwargs): if dat.ndim > 1: dat = np.swapaxes(dat, -2, -1) - ret_comment = kwargs.get('ret_comment', False) - ret_header = kwargs.get('ret_header', False) + ret_comment = kwargs.get("ret_comment", False) + ret_header = kwargs.get("ret_header", False) if ret_comment: if ret_header: return dat, comment, header @@ -304,5 +315,5 @@ def read_data(self, *args, **kwargs): TableSile = tableSile -add_sile('table', tableSile, case=False, gzip=True) -add_sile('dat', tableSile, case=False, gzip=True) +add_sile("table", tableSile, case=False, gzip=True) +add_sile("dat", tableSile, case=False, gzip=True) diff --git a/src/sisl/io/tbtrans/__init__.py b/src/sisl/io/tbtrans/__init__.py index c88f785f9c..40e418cc9d 100644 --- a/src/sisl/io/tbtrans/__init__.py +++ b/src/sisl/io/tbtrans/__init__.py @@ -64,7 +64,7 @@ - `phtprojncSilePHtrans` (projected PHtrans output) """ -from .sile import * # isort: split +from .sile import * # isort: split from .binaries import * from .delta import * from .pht import * diff --git a/src/sisl/io/tbtrans/_cdf.py b/src/sisl/io/tbtrans/_cdf.py index ee126be1c8..7e3d51c266 100644 --- a/src/sisl/io/tbtrans/_cdf.py +++ b/src/sisl/io/tbtrans/_cdf.py @@ -7,6 +7,7 @@ import numpy as np import sisl._array as _a + # Import the geometry object from sisl import Atom, Geometry, Lattice from sisl._indices import indices @@ -19,34 +20,35 @@ from ..sile import SileWarning from .sile import SileCDFTBtrans -__all__ = ['_ncSileTBtrans', '_devncSileTBtrans'] +__all__ = ["_ncSileTBtrans", "_devncSileTBtrans"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ry2eV = unit_convert('Ry', 'eV') -Ry2K = unit_convert('Ry', 'K') -eV2Ry = unit_convert('eV', 'Ry') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ry2eV = unit_convert("Ry", "eV") +Ry2K = unit_convert("Ry", "K") +eV2Ry = unit_convert("eV", "Ry") @set_module("sisl.io.tbtrans") class _ncSileTBtrans(SileCDFTBtrans): - r""" Common TBtrans NetCDF file object due to a lot of the files having common entries + r"""Common TBtrans NetCDF file object due to a lot of the files having common entries This enables easy read of the Geometry and Lattices etc. """ + @lru_cache(maxsize=1) def read_lattice(self): - """ Returns `Lattice` object from this file """ + """Returns `Lattice` object from this file""" cell = _a.arrayd(np.copy(self.cell)) cell.shape = (3, 3) - nsc = self._value('nsc') + nsc = self._value("nsc") lattice = Lattice(cell, nsc=nsc) - lattice.sc_off = self._value('isc_off') + lattice.sc_off = self._value("isc_off") return lattice def read_geometry(self, *args, **kwargs): - """ Returns `Geometry` object from this file """ + """Returns `Geometry` object from this file""" lattice = self.read_lattice() xyz = _a.arrayd(np.copy(self.xa)) @@ -57,20 +59,20 @@ def read_geometry(self, *args, **kwargs): nos = np.append([lasto[0]], np.diff(lasto)) nos = _a.arrayi(nos) - if 'atom' in kwargs: + if "atom" in kwargs: # The user "knows" which atoms are present - atms = kwargs['atom'] + atms = kwargs["atom"] # Check that all atoms have the correct number of orbitals. # Otherwise we will correct them for i in range(len(atms)): if atms[i].no != nos[i]: - atms[i] = Atom(atms[i].Z, [-1] *nos[i], tag=atms[i].tag) + atms[i] = Atom(atms[i].Z, [-1] * nos[i], tag=atms[i].tag) else: # Default to Hydrogen atom with nos[ia] orbitals # This may be counterintuitive but there is no storage of the # actual species - atms = [Atom('H', [-1] * o) for o in nos] + atms = [Atom("H", [-1] * o) for o in nos] # Create and return geometry object geom = Geometry(xyz, atms, lattice=lattice) @@ -83,79 +85,87 @@ def read_geometry(self, *args, **kwargs): @property @lru_cache(maxsize=1) def geometry(self): - """ The associated geometry from this file """ + """The associated geometry from this file""" return self.read_geometry() + geom = geometry @property @lru_cache(maxsize=1) def cell(self): - """ Unit cell in file """ - return self._value('cell') * Bohr2Ang + """Unit cell in file""" + return self._value("cell") * Bohr2Ang @property @lru_cache(maxsize=1) def na(self): - """ Returns number of atoms in the cell """ - return len(self._dimension('na_u')) + """Returns number of atoms in the cell""" + return len(self._dimension("na_u")) + na_u = na @property @lru_cache(maxsize=1) def no(self): - """ Returns number of orbitals in the cell """ - return len(self._dimension('no_u')) + """Returns number of orbitals in the cell""" + return len(self._dimension("no_u")) + no_u = no @property @lru_cache(maxsize=1) def xyz(self): - """ Atomic coordinates in file """ - return self._value('xa') * Bohr2Ang + """Atomic coordinates in file""" + return self._value("xa") * Bohr2Ang + xa = xyz @property @lru_cache(maxsize=1) def lasto(self): - """ Last orbital of corresponding atom """ - return self._value('lasto') - 1 + """Last orbital of corresponding atom""" + return self._value("lasto") - 1 @property @lru_cache(maxsize=1) def k(self): - """ Sampled k-points in file """ - return self._value('kpt') + """Sampled k-points in file""" + return self._value("kpt") + kpt = k @property @lru_cache(maxsize=1) def wk(self): - """ Weights of k-points in file """ - return self._value('wkpt') + """Weights of k-points in file""" + return self._value("wkpt") + wkpt = wk @property @lru_cache(maxsize=1) def nk(self): - """ Number of k-points in file """ - return len(self.dimensions['nkpt']) + """Number of k-points in file""" + return len(self.dimensions["nkpt"]) + nkpt = nk @property @lru_cache(maxsize=1) def E(self): - """ Sampled energy-points in file """ - return self._value('E') * Ry2eV + """Sampled energy-points in file""" + return self._value("E") * Ry2eV @property @lru_cache(maxsize=1) def ne(self): - """ Number of energy-points in file """ - return len(self._dimension('ne')) + """Number of energy-points in file""" + return len(self._dimension("ne")) + nE = ne def Eindex(self, E): - """ Return the closest energy index corresponding to the energy ``E`` + """Return the closest energy index corresponding to the energy ``E`` Parameters ---------- @@ -171,15 +181,19 @@ def Eindex(self, E): idxE = np.abs(self.E - E).argmin() ret_E = self.E[idxE] if abs(ret_E - E) > 5e-3: - warn(f"{self.__class__.__name__} requesting energy " - f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!") + warn( + f"{self.__class__.__name__} requesting energy " + f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!" + ) elif abs(ret_E - E) > 1e-3: - info(f"{self.__class__.__name__} requesting energy " - f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!") + info( + f"{self.__class__.__name__} requesting energy " + f"{E:.5f} eV, found {ret_E:.5f} eV as the closest energy!" + ) return idxE def _bias_window_integrator(self, elec_from=0, elec_to=1): - r""" An integrator for the bias window between two electrodes + r"""An integrator for the bias window between two electrodes Given two chemical potentials this returns an integrator (function) which returns weights for an input energy-point roughly equivalent to: @@ -209,12 +223,14 @@ def _bias_window_integrator(self, elec_from=0, elec_to=1): dE = E[1] - E[0] def integrator(E): - return dE * (fermi_dirac(E, kt_from, mu_from) - fermi_dirac(E, kt_to, mu_to)) + return dE * ( + fermi_dirac(E, kt_from, mu_from) - fermi_dirac(E, kt_to, mu_to) + ) return integrator def kindex(self, k): - """ Return the index of the k-point that is closests to the queried k-point (in reduced coordinates) + """Return the index of the k-point that is closests to the queried k-point (in reduced coordinates) Parameters ---------- @@ -227,33 +243,38 @@ def kindex(self, k): ik = np.sum(np.abs(self.k - _a.asarrayd(k)[None, :]), axis=1).argmin() ret_k = self.k[ik, :] if not np.allclose(ret_k, k, atol=0.0001): - warn(SileWarning(self.__class__.__name__ + " requesting k-point " + - "[{:.3f}, {:.3f}, {:.3f}]".format(*k) + - " found " + - "[{:.3f}, {:.3f}, {:.3f}]".format(*ret_k))) + warn( + SileWarning( + self.__class__.__name__ + + " requesting k-point " + + "[{:.3f}, {:.3f}, {:.3f}]".format(*k) + + " found " + + "[{:.3f}, {:.3f}, {:.3f}]".format(*ret_k) + ) + ) return ik @set_module("sisl.io.tbtrans") class _devncSileTBtrans(_ncSileTBtrans): - r""" Common TBtrans NetCDF file object due to a lot of the files having common entries + r"""Common TBtrans NetCDF file object due to a lot of the files having common entries This one also enables device region atoms and pivoting tables. """ def read_geometry(self, *args, **kwargs): - """ Returns `Geometry` object from this file """ + """Returns `Geometry` object from this file""" g = super().read_geometry(*args, **kwargs) try: - g['Buffer'] = self.a_buf[:] + g["Buffer"] = self.a_buf[:] except Exception: # Then no buffer atoms pass - g['Device'] = self.a_dev[:] + g["Device"] = self.a_dev[:] try: for elec in self.elecs: - g[elec] = self._value('a', [elec]) - 1 - g[f"{elec}+"] = self._value('a_down', [elec]) - 1 + g[elec] = self._value("a", [elec]) - 1 + g[f"{elec}+"] = self._value("a_down", [elec]) - 1 except Exception: pass return g @@ -261,43 +282,45 @@ def read_geometry(self, *args, **kwargs): @property @lru_cache(maxsize=1) def na_b(self): - """ Number of atoms in the buffer region """ - return len(self._dimension('na_b')) + """Number of atoms in the buffer region""" + return len(self._dimension("na_b")) + na_buffer = na_b @property @lru_cache(maxsize=1) def a_buf(self): - """ Atomic indices (0-based) of device atoms """ - return self._value('a_buf') - 1 + """Atomic indices (0-based) of device atoms""" + return self._value("a_buf") - 1 # Device atoms and other quantities @property @lru_cache(maxsize=1) def na_d(self): - """ Number of atoms in the device region """ - return len(self._dimension('na_d')) + """Number of atoms in the device region""" + return len(self._dimension("na_d")) + na_dev = na_d @property @lru_cache(maxsize=1) def a_dev(self): - """ Atomic indices (0-based) of device atoms (sorted) """ - return self._value('a_dev') - 1 + """Atomic indices (0-based) of device atoms (sorted)""" + return self._value("a_dev") - 1 @lru_cache(maxsize=16) def a_elec(self, elec): - """ Electrode atomic indices for the full geometry (sorted) + """Electrode atomic indices for the full geometry (sorted) Parameters ---------- elec : str or int electrode to retrieve indices for """ - return self._value('a', self._elec(elec)) - 1 + return self._value("a", self._elec(elec)) - 1 def a_down(self, elec, bulk=False): - """ Down-folding atomic indices for a given electrode + """Down-folding atomic indices for a given electrode Parameters ---------- @@ -309,12 +332,12 @@ def a_down(self, elec, bulk=False): """ if bulk: return self.a_elec(elec) - return self._value('a_down', self._elec(elec)) - 1 + return self._value("a_down", self._elec(elec)) - 1 @property @lru_cache(maxsize=1) def o_dev(self): - """ Orbital indices (0-based) of device orbitals (sorted) + """Orbital indices (0-based) of device orbitals (sorted) See Also -------- @@ -325,11 +348,11 @@ def o_dev(self): @property @lru_cache(maxsize=1) def no_d(self): - """ Number of orbitals in the device region """ - return len(self.dimensions['no_d']) + """Number of orbitals in the device region""" + return len(self.dimensions["no_d"]) def _elec(self, elec): - """ Converts a string or integer to the corresponding electrode name + """Converts a string or integer to the corresponding electrode name Parameters ---------- @@ -351,18 +374,19 @@ def _elec(self, elec): @property @lru_cache(maxsize=1) def elecs(self): - """ List of electrodes """ + """List of electrodes""" return list(self.groups.keys()) @lru_cache(maxsize=16) def chemical_potential(self, elec): - """ Return the chemical potential associated with the electrode `elec` """ - return self._value('mu', self._elec(elec))[0] * Ry2eV + """Return the chemical potential associated with the electrode `elec`""" + return self._value("mu", self._elec(elec))[0] * Ry2eV + mu = chemical_potential @lru_cache(maxsize=16) def eta(self, elec=None): - """ The imaginary part used when calculating the self-energies in eV (or for the device + """The imaginary part used when calculating the self-energies in eV (or for the device Parameters ---------- @@ -371,13 +395,13 @@ def eta(self, elec=None): region eta will be returned. """ try: - return self._value('eta', self._elec(elec))[0] * self._E2eV + return self._value("eta", self._elec(elec))[0] * self._E2eV except Exception: - return 0. # unknown! + return 0.0 # unknown! @lru_cache(maxsize=16) def electron_temperature(self, elec): - """ Electron bath temperature [Kelvin] + """Electron bath temperature [Kelvin] Parameters ---------- @@ -388,11 +412,11 @@ def electron_temperature(self, elec): -------- kT: bath temperature in [eV] """ - return self._value('kT', self._elec(elec))[0] * Ry2K + return self._value("kT", self._elec(elec))[0] * Ry2K @lru_cache(maxsize=16) def kT(self, elec): - """ Electron bath temperature [eV] + """Electron bath temperature [eV] Parameters ---------- @@ -403,22 +427,22 @@ def kT(self, elec): -------- electron_temperature: bath temperature in [K] """ - return self._value('kT', self._elec(elec))[0] * Ry2eV + return self._value("kT", self._elec(elec))[0] * Ry2eV @lru_cache(maxsize=16) def bloch(self, elec): - """ Bloch-expansion coefficients for an electrode + """Bloch-expansion coefficients for an electrode Parameters ---------- elec : str or int bloch expansions of electrode """ - return self._value('bloch', self._elec(elec)) + return self._value("bloch", self._elec(elec)) @lru_cache(maxsize=16) def n_btd(self, elec=None): - """ Number of blocks in the BTD partioning + """Number of blocks in the BTD partioning Parameters ---------- @@ -426,11 +450,11 @@ def n_btd(self, elec=None): if None the number of blocks in the device region BTD matrix. Else the number of BTD blocks in the electrode down-folding. """ - return len(self._dimension('n_btd', self._elec(elec))) + return len(self._dimension("n_btd", self._elec(elec))) @lru_cache(maxsize=16) def btd(self, elec=None): - """ Block-sizes for the BTD method in the device/electrode region + """Block-sizes for the BTD method in the device/electrode region Parameters ---------- @@ -438,55 +462,55 @@ def btd(self, elec=None): the BTD block sizes for the device (if none), otherwise the downfolding BTD block sizes for the electrode """ - return self._value('btd', self._elec(elec)) + return self._value("btd", self._elec(elec)) @lru_cache(maxsize=16) def na_down(self, elec): - """ Number of atoms in the downfolding region (without device downfolded region) + """Number of atoms in the downfolding region (without device downfolded region) Parameters ---------- elec : str or int Number of downfolding atoms for electrode `elec` """ - return len(self._dimension('na_down', self._elec(elec))) + return len(self._dimension("na_down", self._elec(elec))) @lru_cache(maxsize=16) def no_e(self, elec): - """ Number of orbitals in the downfolded region of the electrode in the device + """Number of orbitals in the downfolded region of the electrode in the device Parameters ---------- elec : str or int Specify the electrode to query number of downfolded orbitals """ - return len(self._dimension('no_e', self._elec(elec))) + return len(self._dimension("no_e", self._elec(elec))) @lru_cache(maxsize=16) def no_down(self, elec): - """ Number of orbitals in the downfolding region (plus device downfolded region) + """Number of orbitals in the downfolding region (plus device downfolded region) Parameters ---------- elec : str or int Number of downfolding orbitals for electrode `elec` """ - return len(self._dimension('no_down', self._elec(elec))) + return len(self._dimension("no_down", self._elec(elec))) @lru_cache(maxsize=16) def pivot_down(self, elec): - """ Pivoting orbitals for the downfolding region of a given electrode + """Pivoting orbitals for the downfolding region of a given electrode Parameters ---------- elec : str or int the corresponding electrode to get the pivoting indices for """ - return self._value('pivot_down', self._elec(elec)) - 1 + return self._value("pivot_down", self._elec(elec)) - 1 @lru_cache(maxsize=32) def pivot(self, elec=None, in_device=False, sort=False): - """ Return the pivoting indices for a specific electrode (in the device region) or the device + """Return the pivoting indices for a specific electrode (in the device region) or the device Parameters ---------- @@ -523,7 +547,7 @@ def pivot(self, elec=None, in_device=False, sort=False): if elec is None: if in_device and sort: return _a.arangei(self.no_d) - pvt = self._value('pivot') - 1 + pvt = self._value("pivot") - 1 if in_device: # Count number of elements that we need to subtract from each orbital subn = _a.onesi(self.no) @@ -534,7 +558,7 @@ def pivot(self, elec=None, in_device=False, sort=False): return pvt # Get electrode pivoting elements - se_pvt = self._value('pivot', tree=self._elec(elec)) - 1 + se_pvt = self._value("pivot", tree=self._elec(elec)) - 1 if sort: # Sort pivoting indices # Since we know that pvt is also sorted, then @@ -543,7 +567,7 @@ def pivot(self, elec=None, in_device=False, sort=False): se_pvt = np.sort(se_pvt) if in_device: - pvt = self._value('pivot') - 1 + pvt = self._value("pivot") - 1 if sort: pvt = np.sort(pvt) # translate to the device indices @@ -551,7 +575,7 @@ def pivot(self, elec=None, in_device=False, sort=False): return se_pvt def a2p(self, atoms): - """ Return the pivoting orbital indices (0-based) for the atoms, possibly on an electrode + """Return the pivoting orbital indices (0-based) for the atoms, possibly on an electrode This is equivalent to: @@ -567,7 +591,7 @@ def a2p(self, atoms): return self.o2p(self.geometry.a2o(atoms, True)) def o2p(self, orbitals, elec=None): - """ Return the pivoting indices (0-based) for the orbitals, possibly on an electrode + """Return the pivoting indices (0-based) for the orbitals, possibly on an electrode Will warn if an orbital requested is not in the device list of orbitals. @@ -584,6 +608,8 @@ def o2p(self, orbitals, elec=None): porb = np.in1d(self.pivot(elec), orbitals).nonzero()[0] d = len(orbitals) - len(porb) if d != 0: - warn(f"{self.__class__.__name__}.o2p requesting an orbital outside the device region, " - f"{d} orbitals will be removed from the returned list") + warn( + f"{self.__class__.__name__}.o2p requesting an orbital outside the device region, " + f"{d} orbitals will be removed from the returned list" + ) return porb diff --git a/src/sisl/io/tbtrans/binaries.py b/src/sisl/io/tbtrans/binaries.py index d1f9343801..aeee003a27 100644 --- a/src/sisl/io/tbtrans/binaries.py +++ b/src/sisl/io/tbtrans/binaries.py @@ -6,16 +6,20 @@ from ..sile import add_sile -__all__ = ['tbtgfSileTBtrans'] +__all__ = ["tbtgfSileTBtrans"] dic = {} try: - dic['__doc__'] = _gfSileSiesta.__doc__.replace(_gfSileSiesta.__name__, 'tbtgfSileTBtrans') + dic["__doc__"] = _gfSileSiesta.__doc__.replace( + _gfSileSiesta.__name__, "tbtgfSileTBtrans" + ) except Exception: pass -tbtgfSileTBtrans = set_module("sisl.io.tbtrans")(type("tbtgfSileTBtrans", (_gfSileSiesta, ), dic)) +tbtgfSileTBtrans = set_module("sisl.io.tbtrans")( + type("tbtgfSileTBtrans", (_gfSileSiesta,), dic) +) del dic -add_sile('TBTGF', tbtgfSileTBtrans) +add_sile("TBTGF", tbtgfSileTBtrans) diff --git a/src/sisl/io/tbtrans/delta.py b/src/sisl/io/tbtrans/delta.py index e7387ff59a..e129f22508 100644 --- a/src/sisl/io/tbtrans/delta.py +++ b/src/sisl/io/tbtrans/delta.py @@ -6,8 +6,10 @@ import numpy as np import sisl._array as _a + # Import the geometry object from sisl import Atom, Geometry, Lattice, SparseOrbitalBZSpin + # Import sile objects from sisl._internal import set_module from sisl.messages import deprecate_argument, warn @@ -27,18 +29,18 @@ has_fortran_module = False -__all__ = ['deltancSileTBtrans'] +__all__ = ["deltancSileTBtrans"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ry2eV = unit_convert('Ry', 'eV') -eV2Ry = unit_convert('eV', 'Ry') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ry2eV = unit_convert("Ry", "eV") +eV2Ry = unit_convert("eV", "Ry") # The delta nc file @set_module("sisl.io.tbtrans") class deltancSileTBtrans(SileCDFTBtrans): - r""" TBtrans :math:`\delta` file object + r"""TBtrans :math:`\delta` file object The :math:`\delta` file object is an extension enabled in `TBtrans`_ which allows changing the Hamiltonian in transport problems. @@ -71,7 +73,7 @@ class deltancSileTBtrans(SileCDFTBtrans): @classmethod def merge(cls, fname, *deltas, **kwargs): - """ Merge several delta files into one Sile which contains the sum of the content + """Merge several delta files into one Sile which contains the sum of the content In cases where implementors use several different delta files it is necessary to merge them into a single delta file before use in TBtrans. @@ -95,23 +97,29 @@ def merge(cls, fname, *deltas, **kwargs): deltas_obj = [] for delta in deltas: if isinstance(delta, (str, Path)): - delta = cls(delta, mode='r') + delta = cls(delta, mode="r") deltas_obj.append(delta) if delta.__class__ != cls: - raise ValueError(f"{cls.__name__}.merge requires all files to be the same class.") + raise ValueError( + f"{cls.__name__}.merge requires all files to be the same class." + ) if delta.file == file: - raise ValueError(f"{cls.__name__}.merge requires that the output file is different from all arguments.") + raise ValueError( + f"{cls.__name__}.merge requires that the output file is different from all arguments." + ) # be sure to overwrite the input with objects deltas = deltas_obj - out = cls(fname, mode='w', **kwargs) + out = cls(fname, mode="w", **kwargs) # Now create and simultaneously check for the same arguments geom = deltas[0].read_geometry() for delta in deltas[1:]: if not geom.equal(delta.read_geometry()): - raise ValueError(f"{cls.__name__}.merge requires that the input files all contain the same geometry.") + raise ValueError( + f"{cls.__name__}.merge requires that the input files all contain the same geometry." + ) # Now we are ready to write out.write_geometry(geom) @@ -192,7 +200,7 @@ def merge(cls, fname, *deltas, **kwargs): out.write_delta(m, E=E, k=k) def has_level(self, ilvl): - """ Query whether the file has level `ilvl` content + """Query whether the file has level `ilvl` content Parameters ---------- @@ -202,14 +210,14 @@ def has_level(self, ilvl): return f"LEVEL-{ilvl}" in self.groups def read_lattice(self): - """ Returns the `Lattice` object from this file """ - cell = _a.arrayd(np.copy(self._value('cell'))) * Bohr2Ang + """Returns the `Lattice` object from this file""" + cell = _a.arrayd(np.copy(self._value("cell"))) * Bohr2Ang cell.shape = (3, 3) - nsc = self._value('nsc') + nsc = self._value("nsc") lattice = Lattice(cell, nsc=nsc) try: - lattice.sc_off = self._value('isc_off') + lattice.sc_off = self._value("isc_off") except Exception: # This is ok, we simply do not have the supercell offsets pass @@ -217,20 +225,20 @@ def read_lattice(self): return lattice def read_geometry(self, *args, **kwargs): - """ Returns the `Geometry` object from this file """ + """Returns the `Geometry` object from this file""" lattice = self.read_lattice() - xyz = _a.arrayd(np.copy(self._value('xa'))) * Bohr2Ang + xyz = _a.arrayd(np.copy(self._value("xa"))) * Bohr2Ang xyz.shape = (-1, 3) # Create list with correct number of orbitals - lasto = _a.arrayi(np.copy(self._value('lasto'))) + lasto = _a.arrayi(np.copy(self._value("lasto"))) nos = np.append([lasto[0]], np.diff(lasto)) nos = _a.arrayi(nos) - if 'atom' in kwargs: + if "atom" in kwargs: # The user "knows" which atoms are present - atms = kwargs['atom'] + atms = kwargs["atom"] # Check that all atoms have the correct number of orbitals. # Otherwise we will correct them for i in range(len(atms)): @@ -241,60 +249,62 @@ def read_geometry(self, *args, **kwargs): # Default to Hydrogen atom with nos[ia] orbitals # This may be counterintuitive but there is no storage of the # actual species - atms = [Atom('H', [-1] * o) for o in nos] + atms = [Atom("H", [-1] * o) for o in nos] # Create and return geometry object geom = Geometry(xyz, atms, lattice=lattice) return geom - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def write_lattice(self, lattice): - """ Creates the NetCDF file and writes the supercell information """ + """Creates the NetCDF file and writes the supercell information""" sile_raise_write(self) # Create initial dimensions - self._crt_dim(self, 'one', 1) - self._crt_dim(self, 'n_s', np.prod(lattice.nsc)) - self._crt_dim(self, 'xyz', 3) + self._crt_dim(self, "one", 1) + self._crt_dim(self, "n_s", np.prod(lattice.nsc)) + self._crt_dim(self, "xyz", 3) # Create initial geometry - v = self._crt_var(self, 'nsc', 'i4', ('xyz',)) - v.info = 'Number of supercells in each unit-cell direction' + v = self._crt_var(self, "nsc", "i4", ("xyz",)) + v.info = "Number of supercells in each unit-cell direction" v[:] = lattice.nsc[:] - v = self._crt_var(self, 'isc_off', 'i4', ('n_s', 'xyz')) + v = self._crt_var(self, "isc_off", "i4", ("n_s", "xyz")) v.info = "Index of supercell coordinates" v[:] = lattice.sc_off[:, :] - v = self._crt_var(self, 'cell', 'f8', ('xyz', 'xyz')) - v.info = 'Unit cell' - v.unit = 'Bohr' + v = self._crt_var(self, "cell", "f8", ("xyz", "xyz")) + v.info = "Unit cell" + v.unit = "Bohr" v[:] = lattice.cell[:, :] / Bohr2Ang # Create designation of the creation - self.method = 'sisl' + self.method = "sisl" def write_geometry(self, geometry): - """ Creates the NetCDF file and writes the geometry information """ + """Creates the NetCDF file and writes the geometry information""" sile_raise_write(self) # Create initial dimensions self.write_lattice(geometry.lattice) - self._crt_dim(self, 'no_s', np.prod(geometry.nsc) * geometry.no) - self._crt_dim(self, 'no_u', geometry.no) - self._crt_dim(self, 'na_u', geometry.na) + self._crt_dim(self, "no_s", np.prod(geometry.nsc) * geometry.no) + self._crt_dim(self, "no_u", geometry.no) + self._crt_dim(self, "na_u", geometry.na) # Create initial geometry - v = self._crt_var(self, 'lasto', 'i4', ('na_u',)) - v.info = 'Last orbital of equivalent atom' - v = self._crt_var(self, 'xa', 'f8', ('na_u', 'xyz')) - v.info = 'Atomic coordinates' - v.unit = 'Bohr' + v = self._crt_var(self, "lasto", "i4", ("na_u",)) + v.info = "Last orbital of equivalent atom" + v = self._crt_var(self, "xa", "f8", ("na_u", "xyz")) + v.info = "Atomic coordinates" + v.unit = "Bohr" # Save stuff - self.variables['xa'][:] = geometry.xyz / Bohr2Ang + self.variables["xa"][:] = geometry.xyz / Bohr2Ang - bs = self._crt_grp(self, 'BASIS') - b = self._crt_var(bs, 'basis', 'i4', ('na_u',)) + bs = self._crt_grp(self, "BASIS") + b = self._crt_var(bs, "basis", "i4", ("na_u",)) b.info = "Basis of each atom by ID" orbs = _a.emptyi([geometry.na]) @@ -305,9 +315,11 @@ def write_geometry(self, geometry): if a.tag in bs.groups: # Assert the file sizes if bs.groups[a.tag].Number_of_orbitals != a.no: - raise ValueError(f"File {self.file} " - "has erroneous data in regards of " - "of the alreay stored dimensions.") + raise ValueError( + f"File {self.file} " + "has erroneous data in regards of " + "of the alreay stored dimensions." + ) else: ba = bs.createGroup(a.tag) ba.ID = np.int32(isp + 1) @@ -318,18 +330,18 @@ def write_geometry(self, geometry): ba.Number_of_orbitals = np.int32(a.no) # Store the lasto variable as the remaining thing to do - self.variables['lasto'][:] = _a.cumsumi(orbs) + self.variables["lasto"][:] = _a.cumsumi(orbs) def _get_lvl_k_E(self, **kwargs): - """ Return level, k and E indices, in that order. + """Return level, k and E indices, in that order. The indices are negative if a new index needs to be created. """ # Determine the type of dH we are storing... - k = kwargs.get('k', None) + k = kwargs.get("k", None) if k is not None: k = _a.asarrayd(k).flatten() - E = kwargs.get('E', None) + E = kwargs.get("E", None) if (k is None) and (E is None): ilvl = 1 @@ -352,16 +364,16 @@ def _get_lvl_k_E(self, **kwargs): # Now determine the energy and k-indices iE = -1 if ilvl in (3, 4): - if lvl.variables['E'].size != 0: - Es = _a.arrayd(lvl.variables['E'][:]) + if lvl.variables["E"].size != 0: + Es = _a.arrayd(lvl.variables["E"][:]) iE = np.argmin(np.abs(Es - E)) if abs(Es[iE] - E) > 0.0001: iE = -1 ik = -1 if ilvl in (2, 4): - if lvl.variables['kpt'].size != 0: - kpt = _a.arrayd(lvl.variables['kpt'][:]) + if lvl.variables["kpt"].size != 0: + kpt = _a.arrayd(lvl.variables["kpt"][:]) kpt.shape = (-1, 3) ik = np.argmin(np.abs(kpt - k[None, :]).sum(axis=1)) if not np.allclose(kpt[ik, :], k, atol=0.0001): @@ -375,27 +387,35 @@ def _get_lvl(self, ilvl): raise ValueError(f"Level {ilvl} does not exist in {self.file}.") def _add_lvl(self, ilvl): - """ Simply adds and returns a group if it does not exist it will be created """ - slvl = f'LEVEL-{ilvl}' + """Simply adds and returns a group if it does not exist it will be created""" + slvl = f"LEVEL-{ilvl}" if slvl in self.groups: lvl = self._crt_grp(self, slvl) else: lvl = self._crt_grp(self, slvl) if ilvl in (2, 4): - self._crt_dim(lvl, 'nkpt', None) - self._crt_var(lvl, 'kpt', 'f8', ('nkpt', 'xyz'), - attrs={'info': 'k-points for delta values', - 'unit': 'b**-1'}) + self._crt_dim(lvl, "nkpt", None) + self._crt_var( + lvl, + "kpt", + "f8", + ("nkpt", "xyz"), + attrs={"info": "k-points for delta values", "unit": "b**-1"}, + ) if ilvl in (3, 4): - self._crt_dim(lvl, 'ne', None) - self._crt_var(lvl, 'E', 'f8', ('ne',), - attrs={'info': 'Energy points for delta values', - 'unit': 'Ry'}) + self._crt_dim(lvl, "ne", None) + self._crt_var( + lvl, + "E", + "f8", + ("ne",), + attrs={"info": "Energy points for delta values", "unit": "Ry"}, + ) return lvl def write_delta(self, delta, **kwargs): - r""" Writes a :math:`\delta` Hamiltonian to the file + r"""Writes a :math:`\delta` Hamiltonian to the file This term may be of @@ -417,7 +437,9 @@ def write_delta(self, delta, **kwargs): """ csr = delta._csr.copy() if csr.nnz == 0: - raise SileError(f"{self!s}.write_overlap cannot write a zero element sparse matrix!") + raise SileError( + f"{self!s}.write_overlap cannot write a zero element sparse matrix!" + ) # convert to siesta thing and store _csr_to_siesta(delta.geometry, csr, diag=False) @@ -428,40 +450,57 @@ def write_delta(self, delta, **kwargs): # Ensure that the geometry is written self.write_geometry(delta.geometry) - self._crt_dim(self, 'spin', len(delta.spin)) + self._crt_dim(self, "spin", len(delta.spin)) # Determine the type of delta we are storing... - k = kwargs.get('k', None) - E = kwargs.get('E', None) + k = kwargs.get("k", None) + E = kwargs.get("E", None) ilvl, ik, iE = self._get_lvl_k_E(**kwargs) lvl = self._add_lvl(ilvl) # Append the sparsity pattern # Create basis group - if 'n_col' in lvl.variables: - if len(lvl.dimensions['nnzs']) != csr.nnz: - raise ValueError("The sparsity pattern stored in delta *MUST* be equivalent for " - "all delta entries [nnz].") - if np.any(lvl.variables['n_col'][:] != csr.ncol[:]): - raise ValueError("The sparsity pattern stored in delta *MUST* be equivalent for " - "all delta entries [n_col].") - if np.any(lvl.variables['list_col'][:] != csr.col[:]+1): - raise ValueError("The sparsity pattern stored in delta *MUST* be equivalent for " - "all delta entries [list_col].") - if np.any(lvl.variables['isc_off'][:] != siesta_sc_off(*delta.geometry.lattice.nsc).T): - raise ValueError("The sparsity pattern stored in delta *MUST* be equivalent for " - "all delta entries [sc_off].") + if "n_col" in lvl.variables: + if len(lvl.dimensions["nnzs"]) != csr.nnz: + raise ValueError( + "The sparsity pattern stored in delta *MUST* be equivalent for " + "all delta entries [nnz]." + ) + if np.any(lvl.variables["n_col"][:] != csr.ncol[:]): + raise ValueError( + "The sparsity pattern stored in delta *MUST* be equivalent for " + "all delta entries [n_col]." + ) + if np.any(lvl.variables["list_col"][:] != csr.col[:] + 1): + raise ValueError( + "The sparsity pattern stored in delta *MUST* be equivalent for " + "all delta entries [list_col]." + ) + if np.any( + lvl.variables["isc_off"][:] + != siesta_sc_off(*delta.geometry.lattice.nsc).T + ): + raise ValueError( + "The sparsity pattern stored in delta *MUST* be equivalent for " + "all delta entries [sc_off]." + ) else: - self._crt_dim(lvl, 'nnzs', csr.nnz) - v = self._crt_var(lvl, 'n_col', 'i4', ('no_u',)) + self._crt_dim(lvl, "nnzs", csr.nnz) + v = self._crt_var(lvl, "n_col", "i4", ("no_u",)) v.info = "Number of non-zero elements per row" v[:] = csr.ncol[:] - v = self._crt_var(lvl, 'list_col', 'i4', ('nnzs',), - chunksizes=(csr.nnz,), **self._cmp_args) + v = self._crt_var( + lvl, + "list_col", + "i4", + ("nnzs",), + chunksizes=(csr.nnz,), + **self._cmp_args, + ) v.info = "Supercell column indices in the sparse format" v[:] = csr.col[:] + 1 # correct for fortran indices - v = self._crt_var(lvl, 'isc_off', 'i4', ('n_s', 'xyz')) + v = self._crt_var(lvl, "isc_off", "i4", ("n_s", "xyz")) v.info = "Index of supercell coordinates" v[:] = siesta_sc_off(*delta.geometry.lattice.nsc).T @@ -469,15 +508,15 @@ def write_delta(self, delta, **kwargs): if ilvl in (3, 4): if iE < 0: # We need to add the new value - iE = lvl.variables['E'].shape[0] - lvl.variables['E'][iE] = E * eV2Ry + iE = lvl.variables["E"].shape[0] + lvl.variables["E"][iE] = E * eV2Ry warn_E = False warn_k = True if ilvl in (2, 4): if ik < 0: - ik = lvl.variables['kpt'].shape[0] - lvl.variables['kpt'][ik, :] = k + ik = lvl.variables["kpt"].shape[0] + lvl.variables["kpt"][ik, :] = k warn_k = False if ilvl == 4 and warn_k and warn_E and False: @@ -485,7 +524,7 @@ def write_delta(self, delta, **kwargs): # point, this warning will proceed... # I.e. even though the variable has not been set, it will WARN # Hence we out-comment this for now... - #warn(f"Overwriting k-point {ik} and energy point {iE} correction.") + # warn(f"Overwriting k-point {ik} and energy point {iE} correction.") pass elif ilvl == 3 and warn_E: warn(f"Overwriting energy point {iE} correction.") @@ -493,21 +532,21 @@ def write_delta(self, delta, **kwargs): warn(f"Overwriting k-point {ik} correction.") if ilvl == 1: - dim = ('spin', 'nnzs') + dim = ("spin", "nnzs") sl = [slice(None)] * 2 csize = [1] * 2 elif ilvl == 2: - dim = ('nkpt', 'spin', 'nnzs') + dim = ("nkpt", "spin", "nnzs") sl = [slice(None)] * 3 sl[0] = ik csize = [1] * 3 elif ilvl == 3: - dim = ('ne', 'spin', 'nnzs') + dim = ("ne", "spin", "nnzs") sl = [slice(None)] * 3 sl[0] = iE csize = [1] * 3 elif ilvl == 4: - dim = ('nkpt', 'ne', 'spin', 'nnzs') + dim = ("nkpt", "ne", "spin", "nnzs") sl = [slice(None)] * 4 sl[0] = ik sl[1] = iE @@ -518,33 +557,50 @@ def write_delta(self, delta, **kwargs): if delta.spin.kind > delta.spin.POLARIZED: print(delta.spin) - raise ValueError(f"{self.__class__.__name__}.write_delta only allows spin-polarized delta values") - - if delta.dtype.kind == 'c': - v1 = self._crt_var(lvl, 'Redelta', 'f8', dim, - chunksizes=csize, - attrs={'info': "Real part of delta", - 'unit': "Ry"}, **self._cmp_args) - v2 = self._crt_var(lvl, 'Imdelta', 'f8', dim, - chunksizes=csize, - attrs={'info': "Imaginary part of delta", - 'unit': "Ry"}, **self._cmp_args) + raise ValueError( + f"{self.__class__.__name__}.write_delta only allows spin-polarized delta values" + ) + + if delta.dtype.kind == "c": + v1 = self._crt_var( + lvl, + "Redelta", + "f8", + dim, + chunksizes=csize, + attrs={"info": "Real part of delta", "unit": "Ry"}, + **self._cmp_args, + ) + v2 = self._crt_var( + lvl, + "Imdelta", + "f8", + dim, + chunksizes=csize, + attrs={"info": "Imaginary part of delta", "unit": "Ry"}, + **self._cmp_args, + ) for i in range(len(delta.spin)): sl[-2] = i v1[sl] = csr._D[:, i].real * eV2Ry v2[sl] = csr._D[:, i].imag * eV2Ry else: - v = self._crt_var(lvl, 'delta', 'f8', dim, - chunksizes=csize, - attrs={'info': "delta", - 'unit': "Ry"}, **self._cmp_args) + v = self._crt_var( + lvl, + "delta", + "f8", + dim, + chunksizes=csize, + attrs={"info": "delta", "unit": "Ry"}, + **self._cmp_args, + ) for i in range(len(delta.spin)): sl[-2] = i v[sl] = csr._D[:, i] * eV2Ry def _read_class(self, cls, **kwargs): - """ Reads a class model from a file """ + """Reads a class model from a file""" # Ensure that the geometry is written geom = self.read_geometry() @@ -556,7 +612,7 @@ def _read_class(self, cls, **kwargs): lvl = self._get_lvl(ilvl) if iE < 0 and ilvl in (3, 4): - E = kwargs.get('E', None) + E = kwargs.get("E", None) raise ValueError(f"Energy {E} eV does not exist in the file.") if ik < 0 and ilvl in (2, 4): raise ValueError("k-point requested does not exist in the file.") @@ -575,25 +631,25 @@ def _read_class(self, cls, **kwargs): sl[1] = iE # Now figure out what data-type the delta is. - if 'Redelta' in lvl.variables: + if "Redelta" in lvl.variables: # It *must* be a complex valued Hamiltonian is_complex = True dtype = np.complex128 - elif 'delta' in lvl.variables: + elif "delta" in lvl.variables: is_complex = False dtype = np.float64 # Get number of spins - nspin = len(self.dimensions['spin']) + nspin = len(self.dimensions["spin"]) # Now create the sparse matrix stuff (we re-create the # array, hence just allocate the smallest amount possible) C = cls(geom, nspin, nnzpr=1, dtype=dtype, orthogonal=True) - C._csr.ncol = _a.arrayi(lvl.variables['n_col'][:]) + C._csr.ncol = _a.arrayi(lvl.variables["n_col"][:]) # Update maximum number of connections (in case future stuff happens) C._csr.ptr = _ncol_to_indptr(C._csr.ncol) - C._csr.col = _a.arrayi(lvl.variables['list_col'][:]) - 1 + C._csr.col = _a.arrayi(lvl.variables["list_col"][:]) - 1 # Copy information over C._csr._nnz = len(C._csr.col) @@ -601,24 +657,24 @@ def _read_class(self, cls, **kwargs): if is_complex: for ispin in range(nspin): sl[-2] = ispin - C._csr._D[:, ispin].real = lvl.variables['Redelta'][sl] * Ry2eV - C._csr._D[:, ispin].imag = lvl.variables['Imdelta'][sl] * Ry2eV + C._csr._D[:, ispin].real = lvl.variables["Redelta"][sl] * Ry2eV + C._csr._D[:, ispin].imag = lvl.variables["Imdelta"][sl] * Ry2eV else: for ispin in range(nspin): sl[-2] = ispin - C._csr._D[:, ispin] = lvl.variables['delta'][sl] * Ry2eV + C._csr._D[:, ispin] = lvl.variables["delta"][sl] * Ry2eV # Convert from isc to sisl isc - _csr_from_sc_off(C.geometry, lvl.variables['isc_off'][:, :], C._csr) + _csr_from_sc_off(C.geometry, lvl.variables["isc_off"][:, :], C._csr) _mat_spin_convert(C) return C def read_delta(self, **kwargs): - """ Reads a delta model from the file """ + """Reads a delta model from the file""" return self._read_class(SparseOrbitalBZSpin, **kwargs) -add_sile('delta.nc', deltancSileTBtrans) -add_sile('dH.nc', deltancSileTBtrans) -add_sile('dSE.nc', deltancSileTBtrans) +add_sile("delta.nc", deltancSileTBtrans) +add_sile("dH.nc", deltancSileTBtrans) +add_sile("dSE.nc", deltancSileTBtrans) diff --git a/src/sisl/io/tbtrans/pht.py b/src/sisl/io/tbtrans/pht.py index 516b7cf9e2..9abaf5be50 100644 --- a/src/sisl/io/tbtrans/pht.py +++ b/src/sisl/io/tbtrans/pht.py @@ -6,45 +6,53 @@ from ..sile import add_sile from .tbt import Ry2eV, Ry2K, tbtavncSileTBtrans, tbtncSileTBtrans -__all__ = ['phtncSilePHtrans', 'phtavncSilePHtrans'] +__all__ = ["phtncSilePHtrans", "phtavncSilePHtrans"] @set_module("sisl.io.phtrans") class phtncSilePHtrans(tbtncSileTBtrans): - """ PHtrans file object """ - _trans_type = 'PHT' - _E2eV = Ry2eV ** 2 + """PHtrans file object""" + + _trans_type = "PHT" + _E2eV = Ry2eV**2 def phonon_temperature(self, elec): - """ Phonon bath temperature [Kelvin] """ - return self._value('kT', self._elec(elec))[0] * Ry2K + """Phonon bath temperature [Kelvin]""" + return self._value("kT", self._elec(elec))[0] * Ry2K def kT(self, elec): - """ Phonon bath temperature [eV] """ - return self._value('kT', self._elec(elec))[0] * Ry2eV + """Phonon bath temperature [eV]""" + return self._value("kT", self._elec(elec))[0] * Ry2eV @set_module("sisl.io.phtrans") class phtavncSilePHtrans(tbtavncSileTBtrans): - """ PHtrans file object """ - _trans_type = 'PHT' - _E2eV = Ry2eV ** 2 + """PHtrans file object""" + + _trans_type = "PHT" + _E2eV = Ry2eV**2 def phonon_temperature(self, elec): - """ Phonon bath temperature [Kelvin] """ - return self._value('kT', self._elec(elec))[0] * Ry2K + """Phonon bath temperature [Kelvin]""" + return self._value("kT", self._elec(elec))[0] * Ry2K def kT(self, elec): - """ Phonon bath temperature [eV] """ - return self._value('kT', self._elec(elec))[0] * Ry2eV - - -for _name in ['chemical_potential', 'electron_temperature', 'kT', - 'current', 'current_parameter', - 'shot_noise', 'noise_power']: + """Phonon bath temperature [eV]""" + return self._value("kT", self._elec(elec))[0] * Ry2eV + + +for _name in [ + "chemical_potential", + "electron_temperature", + "kT", + "current", + "current_parameter", + "shot_noise", + "noise_power", +]: setattr(phtncSilePHtrans, _name, None) setattr(phtavncSilePHtrans, _name, None) -add_sile('PHT.nc', phtncSilePHtrans) -add_sile('PHT.AV.nc', phtavncSilePHtrans) +add_sile("PHT.nc", phtncSilePHtrans) +add_sile("PHT.AV.nc", phtavncSilePHtrans) diff --git a/src/sisl/io/tbtrans/phtproj.py b/src/sisl/io/tbtrans/phtproj.py index 2ac699f14a..9e5d1b3cd5 100644 --- a/src/sisl/io/tbtrans/phtproj.py +++ b/src/sisl/io/tbtrans/phtproj.py @@ -7,14 +7,15 @@ from .tbt import Ry2eV from .tbtproj import tbtprojncSileTBtrans -__all__ = ['phtprojncSilePHtrans'] +__all__ = ["phtprojncSilePHtrans"] @set_module("sisl.io.phtrans") class phtprojncSilePHtrans(tbtprojncSileTBtrans): - """ PHtrans projection file object """ - _trans_type = 'PHT.Proj' - _E2eV = Ry2eV ** 2 + """PHtrans projection file object""" + _trans_type = "PHT.Proj" + _E2eV = Ry2eV**2 -add_sile('PHT.Proj.nc', phtprojncSilePHtrans) + +add_sile("PHT.Proj.nc", phtprojncSilePHtrans) diff --git a/src/sisl/io/tbtrans/se.py b/src/sisl/io/tbtrans/se.py index 65c4b77901..14eb63bf8c 100644 --- a/src/sisl/io/tbtrans/se.py +++ b/src/sisl/io/tbtrans/se.py @@ -10,6 +10,7 @@ from sisl._indices import indices from sisl._internal import set_module + # Import the geometry object from sisl.unit.siesta import unit_convert from sisl.utils import default_ArgumentParser, default_namespace @@ -17,16 +18,16 @@ from ..sile import add_sile from ._cdf import _devncSileTBtrans -__all__ = ['tbtsencSileTBtrans', 'phtsencSilePHtrans'] +__all__ = ["tbtsencSileTBtrans", "phtsencSilePHtrans"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ry2eV = unit_convert('Ry', 'eV') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ry2eV = unit_convert("Ry", "eV") @set_module("sisl.io.tbtrans") class tbtsencSileTBtrans(_devncSileTBtrans): - r""" TBtrans self-energy file object with downfolded self-energies to the device region + r"""TBtrans self-energy file object with downfolded self-energies to the device region The :math:`\Sigma` object contains all self-energies on the specified k- and energy grid projected into the device region. @@ -61,11 +62,11 @@ class tbtsencSileTBtrans(_devncSileTBtrans): >>> np.allclose(Hdev_pvt, Hdev[pvt_dev, pvt_dev.T]) True """ - _trans_type = 'TBT' + _trans_type = "TBT" _E2eV = Ry2eV def self_energy(self, elec, E, k=0, sort=False): - """ Return the self-energy from the electrode `elec` + """Return the self-energy from the electrode `elec` Parameters ---------- @@ -89,8 +90,8 @@ def self_energy(self, elec, E, k=0, sort=False): # When storing fortran arrays in C-type files reading it in # C-codes will transpose the data. # So we have to transpose back to get the correct order - re = self._variable('ReSelfEnergy', tree=tree)[ik, iE].T - im = self._variable('ImSelfEnergy', tree=tree)[ik, iE].T + re = self._variable("ReSelfEnergy", tree=tree)[ik, iE].T + im = self._variable("ImSelfEnergy", tree=tree)[ik, iE].T SE = self._E2eV * re + (1j * self._E2eV) * im if sort: @@ -103,7 +104,7 @@ def self_energy(self, elec, E, k=0, sort=False): return SE def broadening_matrix(self, elec, E, k=0, sort=False): - r""" Return the broadening matrix from the electrode `elec` + r"""Return the broadening matrix from the electrode `elec` The broadening matrix is calculated as: @@ -132,10 +133,10 @@ def broadening_matrix(self, elec, E, k=0, sort=False): # When storing fortran arrays in C-type files reading it in # C-codes will transpose the data. # So we have to transpose back to get the correct order - re = self._variable('ReSelfEnergy', tree=tree)[ik, iE].T - im = self._variable('ImSelfEnergy', tree=tree)[ik, iE].T + re = self._variable("ReSelfEnergy", tree=tree)[ik, iE].T + im = self._variable("ImSelfEnergy", tree=tree)[ik, iE].T - G = - self._E2eV * (im + im.T) + (1j * self._E2eV) * (re - re.T) + G = -self._E2eV * (im + im.T) + (1j * self._E2eV) * (re - re.T) if sort: pvt = self.pivot(elec) idx = np.argsort(pvt) @@ -147,7 +148,7 @@ def broadening_matrix(self, elec, E, k=0, sort=False): return G def self_energy_average(self, elec, E, sort=False): - """ Return the k-averaged average self-energy from the electrode `elec` + """Return the k-averaged average self-energy from the electrode `elec` Parameters ---------- @@ -168,8 +169,8 @@ def self_energy_average(self, elec, E, sort=False): # When storing fortran arrays in C-type files reading it in # C-codes will transpose the data. # So we have to transpose back to get the correct order - re = self._variable('ReSelfEnergyMean', tree=tree)[iE].T - im = self._variable('ImSelfEnergyMean', tree=tree)[iE].T + re = self._variable("ReSelfEnergyMean", tree=tree)[iE].T + im = self._variable("ImSelfEnergyMean", tree=tree)[iE].T SE = self._E2eV * re + (1j * self._E2eV) * im if sort: @@ -183,7 +184,7 @@ def self_energy_average(self, elec, E, sort=False): return SE def info(self, elec=None): - """ Information about the self-energy file available for extracting in this file + """Information about the self-energy file available for extracting in this file Parameters ---------- @@ -195,13 +196,14 @@ def info(self, elec=None): # Create a StringIO object to retain the information out = StringIO() + # Create wrapper function def prnt(*args, **kwargs): - option = kwargs.pop('option', None) + option = kwargs.pop("option", None) if option is None: print(*args, file=out) else: - print('{:60s}[{}]'.format(' '.join(args), ', '.join(option)), file=out) + print("{:60s}[{}]".format(" ".join(args), ", ".join(option)), file=out) # Retrieve the device atoms prnt("Device information:") @@ -213,14 +215,18 @@ def prnt(*args, **kwargs): nA = len(np.unique(kpt[:, 0])) nB = len(np.unique(kpt[:, 1])) nC = len(np.unique(kpt[:, 2])) - prnt((" - number of kpoints: {} <- " - "[ A = {} , B = {} , C = {} ] (time-reversal unknown)").format(self.nk, nA, nB, nC)) + prnt( + ( + " - number of kpoints: {} <- " + "[ A = {} , B = {} , C = {} ] (time-reversal unknown)" + ).format(self.nk, nA, nB, nC) + ) prnt(" - energy range:") E = self.E Em, EM = np.amin(E), np.amax(E) dE = np.diff(E) - dEm, dEM = np.amin(dE) * 1000, np.amax(dE) * 1000 # convert to meV - if (dEM - dEm) < 1e-3: # 0.001 meV + dEm, dEM = np.amin(dE) * 1000, np.amax(dE) * 1000 # convert to meV + if (dEM - dEm) < 1e-3: # 0.001 meV prnt(f" {Em:.5f} -- {EM:.5f} eV [{dEm:.3f} meV]") else: prnt(f" {Em:.5f} -- {EM:.5f} eV [{dEm:.3f} -- {dEM:.3f} meV]") @@ -246,16 +252,28 @@ def prnt(*args, **kwargs): try: n_btd = self.n_btd(elec) except Exception: - n_btd = 'unknown' + n_btd = "unknown" prnt() prnt(f"Electrode: {elec}") prnt(f" - number of BTD blocks: {n_btd}") prnt(" - Bloch: [{}, {}, {}]".format(*bloch)) - if 'TBT' in self._trans_type: - prnt(" - chemical potential: {:.4f} eV".format(self.chemical_potential(elec))) - prnt(" - electron temperature: {:.2f} K".format(self.electron_temperature(elec))) + if "TBT" in self._trans_type: + prnt( + " - chemical potential: {:.4f} eV".format( + self.chemical_potential(elec) + ) + ) + prnt( + " - electron temperature: {:.2f} K".format( + self.electron_temperature(elec) + ) + ) else: - prnt(" - phonon temperature: {:.4f} K".format(self.phonon_temperature(elec))) + prnt( + " - phonon temperature: {:.4f} K".format( + self.phonon_temperature(elec) + ) + ) prnt(" - imaginary part (eta): {:.4f} meV".format(self.eta(elec) * 1e3)) prnt(" - atoms in down-folding region (not in device):") prnt(" " + list2str(self.a_down(elec) + 1)) @@ -266,9 +284,11 @@ def prnt(*args, **kwargs): out.close() return s - @default_ArgumentParser(description="Show information about data in a TBT.SE.nc file") + @default_ArgumentParser( + description="Show information about data in a TBT.SE.nc file" + ) def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" # We limit the import to occur here import argparse @@ -276,29 +296,36 @@ def ArgumentParser(self, p=None, *args, **kwargs): namespace = default_namespace(_tbtse=self, _geometry=self.geom) class Info(argparse.Action): - """ Action to print information contained in the TBT.SE.nc file, helpful before performing actions """ + """Action to print information contained in the TBT.SE.nc file, helpful before performing actions""" def __call__(self, parser, ns, value, option_string=None): # First short-hand the file print(ns._tbtse.info(value)) - p.add_argument('--info', '-i', action=Info, nargs='?', metavar='ELEC', - help='Print out what information is contained in the TBT.SE.nc file, optionally only for one of the electrodes.') + p.add_argument( + "--info", + "-i", + action=Info, + nargs="?", + metavar="ELEC", + help="Print out what information is contained in the TBT.SE.nc file, optionally only for one of the electrodes.", + ) return p, namespace -add_sile('TBT.SE.nc', tbtsencSileTBtrans) +add_sile("TBT.SE.nc", tbtsencSileTBtrans) # Add spin-dependent files -add_sile('TBT_UP.SE.nc', tbtsencSileTBtrans) -add_sile('TBT_DN.SE.nc', tbtsencSileTBtrans) +add_sile("TBT_UP.SE.nc", tbtsencSileTBtrans) +add_sile("TBT_DN.SE.nc", tbtsencSileTBtrans) @set_module("sisl.io.phtrans") class phtsencSilePHtrans(tbtsencSileTBtrans): - """ PHtrans file object """ - _trans_type = 'PHT' - _E2eV = Ry2eV ** 2 + """PHtrans file object""" + + _trans_type = "PHT" + _E2eV = Ry2eV**2 -add_sile('PHT.SE.nc', phtsencSilePHtrans) +add_sile("PHT.SE.nc", phtsencSilePHtrans) diff --git a/src/sisl/io/tbtrans/sile.py b/src/sisl/io/tbtrans/sile.py index 340b36d171..4920a9feb0 100644 --- a/src/sisl/io/tbtrans/sile.py +++ b/src/sisl/io/tbtrans/sile.py @@ -5,7 +5,7 @@ from ..sile import Sile, SileBin, SileCDF -__all__ = ['SileTBtrans', 'SileCDFTBtrans', 'SileBinTBtrans'] +__all__ = ["SileTBtrans", "SileCDFTBtrans", "SileBinTBtrans"] @set_module("sisl.io.tbtrans") @@ -15,7 +15,6 @@ class SileTBtrans(Sile): @set_module("sisl.io.tbtrans") class SileCDFTBtrans(SileCDF): - # all netcdf output should not be masked def _setup(self, *args, **kwargs): super()._setup(*args, **kwargs) diff --git a/src/sisl/io/tbtrans/tbt.py b/src/sisl/io/tbtrans/tbt.py index 6d47a90c26..0350ed9ddd 100644 --- a/src/sisl/io/tbtrans/tbt.py +++ b/src/sisl/io/tbtrans/tbt.py @@ -27,6 +27,7 @@ from sisl.physics.distribution import fermi_dirac from sisl.sparse import _ncol_to_indptr from sisl.unit.siesta import unit_convert + # Import sile objects from sisl.utils import ( collect_action, @@ -41,17 +42,19 @@ from ..sile import add_sile, get_sile, sile_raise_write from ._cdf import _devncSileTBtrans -__all__ = ['tbtncSileTBtrans', 'tbtavncSileTBtrans'] +__all__ = ["tbtncSileTBtrans", "tbtavncSileTBtrans"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ry2eV = unit_convert('Ry', 'eV') -Ry2K = unit_convert('Ry', 'K') -eV2Ry = unit_convert('eV', 'Ry') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ry2eV = unit_convert("Ry", "eV") +Ry2K = unit_convert("Ry", "K") +eV2Ry = unit_convert("eV", "Ry") -def window_warning(routine, E, elec_from, mu_from, kt_from, elec_to, mu_to, kt_to, kT_factor=3): - """ Issue a warning if the energy grid does not the chemical potentials """ +def window_warning( + routine, E, elec_from, mu_from, kt_from, elec_to, mu_to, kt_to, kT_factor=3 +): + """Issue a warning if the energy grid does not the chemical potentials""" Emin = E.min() Emax = E.max() @@ -63,29 +66,43 @@ def window_warning(routine, E, elec_from, mu_from, kt_from, elec_to, mu_to, kt_t dE = E[1] - E[0] # Check that the lower bound is sufficient - print_warning = mu_from - kt_from * kT_factor < Emin - dE / 2 or \ - mu_to - kt_to * kT_factor < Emin - dE / 2 - print_warning = mu_from + kt_from * kT_factor > Emax + dE / 2 or \ - mu_to + kt_to * kT_factor > Emax + dE / 2 or \ - print_warning + print_warning = ( + mu_from - kt_from * kT_factor < Emin - dE / 2 + or mu_to - kt_to * kT_factor < Emin - dE / 2 + ) + print_warning = ( + mu_from + kt_from * kT_factor > Emax + dE / 2 + or mu_to + kt_to * kT_factor > Emax + dE / 2 + or print_warning + ) if print_warning: # We should pretty-print a table of data m = max(len(elec_from), len(elec_to), 15) - s = ("{:"+str(m)+"s} {:9.3f} : {:9.3f} eV\n").format('Energy range', Emin - dE / 2, Emax + dE / 2) - s += ("{:"+str(m)+"s} {:9.3f} : {:9.3f} eV\n").format(elec_from, mu_from - kt_from * kT_factor, mu_from + kt_from * kT_factor) - s += ("{:"+str(m)+"s} {:9.3f} : {:9.3f} eV\n").format(elec_to, mu_to - kt_to * kT_factor, mu_to + kt_to * kT_factor) + s = ("{:" + str(m) + "s} {:9.3f} : {:9.3f} eV\n").format( + "Energy range", Emin - dE / 2, Emax + dE / 2 + ) + s += ("{:" + str(m) + "s} {:9.3f} : {:9.3f} eV\n").format( + elec_from, mu_from - kt_from * kT_factor, mu_from + kt_from * kT_factor + ) + s += ("{:" + str(m) + "s} {:9.3f} : {:9.3f} eV\n").format( + elec_to, mu_to - kt_to * kT_factor, mu_to + kt_to * kT_factor + ) min_e = min(mu_from - kt_from * kT_factor, mu_to - kt_to * kT_factor) max_e = max(mu_from + kt_from * kT_factor, mu_to + kt_to * kT_factor) - s += ("{:"+str(m)+"s} {:9.3f} : {:9.3f} eV\n").format('dFermi function', min_e, max_e) + s += ("{:" + str(m) + "s} {:9.3f} : {:9.3f} eV\n").format( + "dFermi function", min_e, max_e + ) - warn(f"{routine} cannot " - "accurately calculate the current due to the calculated energy range. " - "Increase the calculated energy-range.\n{s}") + warn( + f"{routine} cannot " + "accurately calculate the current due to the calculated energy range. " + "Increase the calculated energy-range.\n{s}" + ) @set_module("sisl.io.tbtrans") class tbtncSileTBtrans(_devncSileTBtrans): - r""" TBtrans output file object + r"""TBtrans output file object Implementation of the TBtrans output ``*.TBT.nc`` files which contains calculated quantities related to the NEGF code TBtrans. @@ -111,13 +128,13 @@ class tbtncSileTBtrans(_devncSileTBtrans): The API for this class are largely equivalent to the arguments of the `sdata` command-line tool, with the execption that the command-line tool uses Fortran indexing numbers (1-based). """ - _trans_type = 'TBT' + _trans_type = "TBT" _E2eV = Ry2eV _k_avg = False def write_tbtav(self, *args, **kwargs): - """ Convert this to a TBT.AV.nc file, i.e. all k dependent quantites are averaged out. + """Convert this to a TBT.AV.nc file, i.e. all k dependent quantites are averaged out. This command will overwrite any previous file with the ending TBT.AV.nc and thus will not take notice of any older files. @@ -127,14 +144,14 @@ def write_tbtav(self, *args, **kwargs): file : str output filename """ - f = self._file.with_suffix('.AV.nc') + f = self._file.with_suffix(".AV.nc") if len(args) > 0: f = args[0] - f = kwargs.get('file', f) - tbtavncSileTBtrans(f, mode='w', access=0).write_tbtav(self) + f = kwargs.get("file", f) + tbtavncSileTBtrans(f, mode="w", access=0).write_tbtav(self) def _value_avg(self, name, tree=None, kavg=False): - """ Local method for obtaining the data from the SileCDF. + """Local method for obtaining the data from the SileCDF. This method checks how the file is access, i.e. whether data is stored in the object or it should be read consequtively. @@ -148,12 +165,16 @@ def _value_avg(self, name, tree=None, kavg=False): except KeyError as err: group = None if isinstance(tree, list): - group = '.'.join(tree) + group = ".".join(tree) elif not tree is None: group = tree if not group is None: - raise KeyError(f"{self.__class__.__name__} could not retrieve key '{group}.{name}' due to missing flags in the input file.") - raise KeyError(f"{self.__class__.__name__} could not retrieve key '{name}' due to missing flags in the input file.") + raise KeyError( + f"{self.__class__.__name__} could not retrieve key '{group}.{name}' due to missing flags in the input file." + ) + raise KeyError( + f"{self.__class__.__name__} could not retrieve key '{name}' due to missing flags in the input file." + ) if self._k_avg: return v[:] @@ -176,13 +197,15 @@ def _value_avg(self, name, tree=None, kavg=False): data.shape = orig_shape[1:] else: - raise ValueError(f"{self.__class__.__name__} requires kavg argument to be either bool or an integer corresponding to the k-point index.") + raise ValueError( + f"{self.__class__.__name__} requires kavg argument to be either bool or an integer corresponding to the k-point index." + ) # Return data return data def _value_E(self, name, tree=None, kavg=False, E=None): - """ Local method for obtaining the data from the SileCDF using an E index. """ + """Local method for obtaining the data from the SileCDF using an E index.""" if E is None: return self._value_avg(name, tree, kavg) @@ -194,12 +217,16 @@ def _value_E(self, name, tree=None, kavg=False, E=None): except KeyError: group = None if isinstance(tree, list): - group = '.'.join(tree) + group = ".".join(tree) elif not tree is None: group = tree if not group is None: - raise KeyError(f"{self.__class__.__name__} could not retrieve key '{group}.{name}' due to missing flags in the input file.") - raise KeyError(f"{self.__class__.__name__} could not retrieve key '{name}' due to missing flags in the input file.") + raise KeyError( + f"{self.__class__.__name__} could not retrieve key '{group}.{name}' due to missing flags in the input file." + ) + raise KeyError( + f"{self.__class__.__name__} could not retrieve key '{name}' due to missing flags in the input file." + ) if self._k_avg: return v[iE, ...] @@ -223,13 +250,15 @@ def _value_E(self, name, tree=None, kavg=False, E=None): data.shape = orig_shape[2:] else: - raise ValueError(f"{self.__class__.__name__} requires kavg argument to be either bool or an integer corresponding to the k-point index.") + raise ValueError( + f"{self.__class__.__name__} requires kavg argument to be either bool or an integer corresponding to the k-point index." + ) # Return data return data def transmission(self, elec_from=0, elec_to=1, kavg=True) -> ndarray: - r""" Transmission from `elec_from` to `elec_to`. + r"""Transmission from `elec_from` to `elec_to`. The transmission between two electrodes may be retrieved from the `Sile`. @@ -261,12 +290,14 @@ def transmission(self, elec_from=0, elec_to=1, kavg=True) -> ndarray: elec_from = self._elec(elec_from) elec_to = self._elec(elec_to) if elec_from == elec_to: - raise ValueError(f"{self.__class__.__name__}.transmission elec_from[{elec_from}] and elec_to[{elec_to}] must not be the same.") + raise ValueError( + f"{self.__class__.__name__}.transmission elec_from[{elec_from}] and elec_to[{elec_to}] must not be the same." + ) return self._value_avg(f"{elec_to}.T", elec_from, kavg=kavg) def reflection(self, elec=0, kavg=True, from_single=False) -> ndarray: - r""" Reflection into electrode `elec` + r"""Reflection into electrode `elec` The reflection into electrode `elec` is calculated as: @@ -306,9 +337,11 @@ def reflection(self, elec=0, kavg=True, from_single=False) -> ndarray: # Find full transmission out of electrode if from_single: - T = self._value_avg(f"{elec}.T", elec, kavg=kavg) - self._value_avg(f"{elec}.C", elec, kavg=kavg) + T = self._value_avg(f"{elec}.T", elec, kavg=kavg) - self._value_avg( + f"{elec}.C", elec, kavg=kavg + ) else: - T = 0. + T = 0.0 for to in self.elecs: to = self._elec(to) if elec == to: @@ -318,7 +351,7 @@ def reflection(self, elec=0, kavg=True, from_single=False) -> ndarray: return BT - T def transmission_eig(self, elec_from=0, elec_to=1, kavg=True) -> ndarray: - """ Transmission eigenvalues from `elec_from` to `elec_to`. + """Transmission eigenvalues from `elec_from` to `elec_to`. Parameters ---------- @@ -338,12 +371,14 @@ def transmission_eig(self, elec_from=0, elec_to=1, kavg=True) -> ndarray: elec_from = self._elec(elec_from) elec_to = self._elec(elec_to) if elec_from == elec_to: - raise ValueError(f"{self.__class__.__name__}.transmission_eig elec_from[{elec_from}] and elec_to[{elec_to}] must not be the same.") + raise ValueError( + f"{self.__class__.__name__}.transmission_eig elec_from[{elec_from}] and elec_to[{elec_to}] must not be the same." + ) return self._value_avg(f"{elec_to}.T.Eig", elec_from, kavg=kavg) def transmission_bulk(self, elec=0, kavg=True) -> ndarray: - """ Bulk transmission for the `elec` electrode + """Bulk transmission for the `elec` electrode The bulk transmission is equivalent to creating a 2 terminal device with electrode `elec` tiled 3 times. @@ -364,8 +399,8 @@ def transmission_bulk(self, elec=0, kavg=True) -> ndarray: """ return self._value_avg("T", self._elec(elec), kavg=kavg) - def norm(self, atoms=None, orbitals=None, norm='none') -> int: - r""" Normalization factor depending on the input + def norm(self, atoms=None, orbitals=None, norm="none") -> int: + r"""Normalization factor depending on the input The normalization can be performed in one of the below methods. In the following :math:`N` refers to the normalization constant @@ -397,12 +432,14 @@ def norm(self, atoms=None, orbitals=None, norm='none') -> int: """ # Cast to lower norm = norm.lower() - if norm == 'none': + if norm == "none": NORM = 1 - elif norm in ['all', 'atom', 'orbital']: + elif norm in ["all", "atom", "orbital"]: NORM = self.no_d else: - raise ValueError(f"{self.__class__.__name__}.norm error on norm keyword in when requesting normalization!") + raise ValueError( + f"{self.__class__.__name__}.norm error on norm keyword in when requesting normalization!" + ) # If the user simply requests a specific norm if atoms is None and orbitals is None: @@ -411,9 +448,9 @@ def norm(self, atoms=None, orbitals=None, norm='none') -> int: # Now figure out what to do if atoms is None: # Get pivoting indices to average over - if norm == 'orbital': + if norm == "orbital": NORM = len(self.o2p(orbitals)) - elif norm == 'atom': + elif norm == "atom": geom = self.geometry a = np.unique(geom.o2a(orbitals)) # Now sum the orbitals per atom @@ -421,7 +458,9 @@ def norm(self, atoms=None, orbitals=None, norm='none') -> int: return NORM if not orbitals is None: - raise ValueError(f"{self.__class__.__name__}.norm both atom and orbital cannot be specified!") + raise ValueError( + f"{self.__class__.__name__}.norm both atom and orbital cannot be specified!" + ) # atom is specified, this will result in the same normalization # regardless of norm == [orbital, atom] since it is all orbitals @@ -432,7 +471,7 @@ def norm(self, atoms=None, orbitals=None, norm='none') -> int: return NORM def _DOS(self, DOS, atoms, orbitals, sum, norm) -> ndarray: - """ Averages/sums the DOS + """Averages/sums the DOS Parameters ---------- @@ -457,17 +496,21 @@ def _DOS(self, DOS, atoms, orbitals, sum, norm) -> ndarray: """ # Force False equivalent as None. if isinstance(atoms, bool): - if not atoms: atoms = None + if not atoms: + atoms = None if isinstance(orbitals, bool): - if not orbitals: orbitals = None + if not orbitals: + orbitals = None if not atoms is None and not orbitals is None: - raise ValueError("Both atoms and orbitals keyword in DOS request " - "cannot be specified, only one at a time.") + raise ValueError( + "Both atoms and orbitals keyword in DOS request " + "cannot be specified, only one at a time." + ) # Cast to lower norm = norm.lower() - if norm == 'none': - NORM = 1. - elif norm in ['all', 'atom', 'orbital']: + if norm == "none": + NORM = 1.0 + elif norm in ["all", "atom", "orbital"]: NORM = float(self.no_d) else: raise ValueError("Error on norm keyword in DOS request") @@ -487,18 +530,18 @@ def _DOS(self, DOS, atoms, orbitals, sum, norm) -> ndarray: # orbital *must* be specified if isinstance(orbitals, bool): # Request all orbitals of the device - orbitals = geom.a2o('Device', all=True) + orbitals = geom.a2o("Device", all=True) elif isinstance(orbitals, str): orbitals = geom.a2o(orbitals, all=True) # Get pivoting indices to average over p = self.o2p(orbitals) - if norm == 'orbital': + if norm == "orbital": NORM = float(len(p)) - elif norm == 'atom': + elif norm == "atom": a = geom.o2a(orbitals, unique=True) # Now sum the orbitals per atom - NORM = float(_a.sumi(geom.firsto[a+1] - geom.firsto[a])) + NORM = float(_a.sumi(geom.firsto[a + 1] - geom.firsto[a])) if sum: return DOS[..., p].sum(-1) / NORM @@ -508,14 +551,14 @@ def _DOS(self, DOS, atoms, orbitals, sum, norm) -> ndarray: # Check if user requests all atoms/orbitals if isinstance(atoms, bool): # Request all atoms of the device - atoms = geom.names['Device'] + atoms = geom.names["Device"] elif isinstance(atoms, str): atoms = geom.names[atoms] # atom is specified # Return the pivoting orbitals for the atom p = self.a2p(atoms) - if norm in ['orbital', 'atom']: + if norm in ["orbital", "atom"]: NORM = float(len(p)) if sum or isinstance(atoms, Integral): @@ -541,8 +584,10 @@ def _DOS(self, DOS, atoms, orbitals, sum, norm) -> ndarray: return nDOS - def DOS(self, E=None, kavg=True, atoms=None, orbitals=None, sum=True, norm='none') -> ndarray: - r""" Green function density of states (DOS) (1/eV). + def DOS( + self, E=None, kavg=True, atoms=None, orbitals=None, sum=True, norm="none" + ) -> ndarray: + r"""Green function density of states (DOS) (1/eV). Extract the DOS on a selected subset of atoms/orbitals in the device region @@ -578,10 +623,22 @@ def DOS(self, E=None, kavg=True, atoms=None, orbitals=None, sum=True, norm='none ADOS : the spectral density of states from an electrode BDOS : the bulk density of states in an electrode """ - return self._DOS(self._value_E('DOS', kavg=kavg, E=E), atoms, orbitals, sum, norm) * eV2Ry - - def ADOS(self, elec=0, E=None, kavg=True, atoms=None, orbitals=None, sum=True, norm='none') -> ndarray: - r""" Spectral density of states (DOS) (1/eV). + return ( + self._DOS(self._value_E("DOS", kavg=kavg, E=E), atoms, orbitals, sum, norm) + * eV2Ry + ) + + def ADOS( + self, + elec=0, + E=None, + kavg=True, + atoms=None, + orbitals=None, + sum=True, + norm="none", + ) -> ndarray: + r"""Spectral density of states (DOS) (1/eV). Extract the spectral DOS from electrode `elec` on a selected subset of atoms/orbitals in the device region @@ -619,10 +676,15 @@ def ADOS(self, elec=0, E=None, kavg=True, atoms=None, orbitals=None, sum=True, n BDOS : the bulk density of states in an electrode """ elec = self._elec(elec) - return self._DOS(self._value_E('ADOS', elec, kavg=kavg, E=E), atoms, orbitals, sum, norm) * eV2Ry + return ( + self._DOS( + self._value_E("ADOS", elec, kavg=kavg, E=E), atoms, orbitals, sum, norm + ) + * eV2Ry + ) - def BDOS(self, elec=0, E=None, kavg=True, sum=True, norm='none') -> ndarray: - r""" Bulk density of states (DOS) (1/eV). + def BDOS(self, elec=0, E=None, kavg=True, sum=True, norm="none") -> ndarray: + r"""Bulk density of states (DOS) (1/eV). Extract the bulk DOS from electrode `elec` on a selected subset of atoms/orbitals in the device region @@ -654,18 +716,18 @@ def BDOS(self, elec=0, E=None, kavg=True, sum=True, norm='none') -> ndarray: # Hence the non-normalized quantity needs to be multiplied by # product(bloch) elec = self._elec(elec) - if norm in ['atom', 'orbital', 'all']: + if norm in ["atom", "orbital", "all"]: # This is normalized per non-expanded unit-cell, so no need to do Bloch - fact = eV2Ry / len(self._dimension('no_u', elec)) + fact = eV2Ry / len(self._dimension("no_u", elec)) else: fact = eV2Ry if sum: - return self._value_E('DOS', elec, kavg=kavg, E=E).sum(-1) * fact + return self._value_E("DOS", elec, kavg=kavg, E=E).sum(-1) * fact else: - return self._value_E('DOS', elec, kavg=kavg, E=E) * fact + return self._value_E("DOS", elec, kavg=kavg, E=E) * fact def _E_T_sorted(self, elec_from, elec_to, kavg=True): - """ Internal routine for returning energies and transmission in a sorted array """ + """Internal routine for returning energies and transmission in a sorted array""" E = self.E idx_sort = np.argsort(E) # Get transmission @@ -673,7 +735,7 @@ def _E_T_sorted(self, elec_from, elec_to, kavg=True): return E[idx_sort], T[idx_sort] def current(self, elec_from=0, elec_to=1, kavg=True) -> float: - r""" Current from `from` to `to` using the k-weights and energy spacings in the file. + r"""Current from `from` to `to` using the k-weights and energy spacings in the file. Calculates the current as: @@ -702,12 +764,12 @@ def current(self, elec_from=0, elec_to=1, kavg=True) -> float: kt_f = self.kT(elec_from) mu_t = self.chemical_potential(elec_to) kt_t = self.kT(elec_to) - return self.current_parameter(elec_from, mu_f, kt_f, - elec_to, mu_t, kt_t, kavg) + return self.current_parameter(elec_from, mu_f, kt_f, elec_to, mu_t, kt_t, kavg) - def current_parameter(self, elec_from, mu_from, kt_from, - elec_to, mu_to, kt_to, kavg=True) -> float: - r""" Current from `from` to `to` using the k-weights and energy spacings in the file. + def current_parameter( + self, elec_from, mu_from, kt_from, elec_to, mu_to, kt_to, kavg=True + ) -> float: + r"""Current from `from` to `to` using the k-weights and energy spacings in the file. Calculates the current as: @@ -745,21 +807,32 @@ def current_parameter(self, elec_from, mu_from, kt_from, E, T = self._E_T_sorted(elec_from, elec_to, kavg) dE = E[1] - E[0] - window_warning(f"{self.__class__.__name__}.current_parameter", E, - elec_from, mu_from, kt_from, - elec_to, mu_to, kt_to) - - I = (T * dE * (fermi_dirac(E, kt_from, mu_from) - fermi_dirac(E, kt_to, mu_to))).sum() - return I * constant.q / constant.h('eV s') + window_warning( + f"{self.__class__.__name__}.current_parameter", + E, + elec_from, + mu_from, + kt_from, + elec_to, + mu_to, + kt_to, + ) + + I = ( + T * dE * (fermi_dirac(E, kt_from, mu_from) - fermi_dirac(E, kt_to, mu_to)) + ).sum() + return I * constant.q / constant.h("eV s") def _check_Teig(self, func_name, TE, eps=0.001): - """ Internal method to check whether all transmission eigenvalues are present """ + """Internal method to check whether all transmission eigenvalues are present""" if np.any(np.logical_and.reduce(TE > eps, axis=-1)): - info(f"{self.__class__.__name__}.{func_name} does possibly not have all relevant transmission eigenvalues in the " - "calculation. For some energy values all transmission eigenvalues are above {eps}!") + info( + f"{self.__class__.__name__}.{func_name} does possibly not have all relevant transmission eigenvalues in the " + "calculation. For some energy values all transmission eigenvalues are above {eps}!" + ) def shot_noise(self, elec_from=0, elec_to=1, classical=False, kavg=True) -> ndarray: - r""" Shot-noise term `from` to `to` using the k-weights + r"""Shot-noise term `from` to `to` using the k-weights Calculates the shot-noise term according to `classical` (also known as the Poisson value). If `classical` is True the shot-noise calculated is: @@ -801,7 +874,7 @@ def shot_noise(self, elec_from=0, elec_to=1, classical=False, kavg=True) -> ndar # Pre-factor # 2 e ^ 3 V / h # Note that h in eV units will cancel the units in the applied bias - noise_const = 2 * constant.q ** 2 * (eV / constant.h('eV s')) + noise_const = 2 * constant.q**2 * (eV / constant.h("eV s")) if classical: # Calculate the Poisson shot-noise (equal to 2eI in the low T and zero kT limit) return noise_const * self.transmission(elec_from, elec_to, kavg=kavg) @@ -811,29 +884,29 @@ def shot_noise(self, elec_from=0, elec_to=1, classical=False, kavg=True) -> ndar if not kavg: # The user wants it k-resolved T = self.transmission_eig(elec_from, elec_to, kavg=False) - self._check_Teig('shot_noise', T) + self._check_Teig("shot_noise", T) return noise_const * (T * (1 - T)).sum(-1) # We need to manually weigh the k-points wkpt = self.wkpt T = self.transmission_eig(elec_from, elec_to, kavg=0) - self._check_Teig('shot_noise', T) + self._check_Teig("shot_noise", T) sn = (T * (1 - T)).sum(-1) * wkpt[0] for ik in range(1, self.nkpt): T = self.transmission_eig(elec_from, elec_to, kavg=ik) - self._check_Teig('shot_noise', T) + self._check_Teig("shot_noise", T) sn += (T * (1 - T)).sum(-1) * wkpt[ik] else: T = self.transmission_eig(elec_from, elec_to, kavg=kavg) - self._check_Teig('shot_noise', T) + self._check_Teig("shot_noise", T) sn = (T * (1 - T)).sum(-1) return noise_const * sn def noise_power(self, elec_from=0, elec_to=1, kavg=True) -> ndarray: - r""" Noise power `from` to `to` using the k-weights and energy spacings in the file (temperature dependent) + r"""Noise power `from` to `to` using the k-weights and energy spacings in the file (temperature dependent) Calculates the noise power as @@ -882,32 +955,42 @@ def noise_power(self, elec_from=0, elec_to=1, kavg=True) -> ndarray: # Pre-factor # 2 e ^ 2 / h # Note that h in eV units will cancel the units in the dE integration - noise_const = 2 * constant.q ** 2 / constant.h('eV s') + noise_const = 2 * constant.q**2 / constant.h("eV s") # Determine the k-average if isinstance(kavg, bool): if not kavg: # The user wants it k-resolved T = self.transmission_eig(elec_from, elec_to, kavg=False) - self._check_Teig('noise_power', T) - return noise_const * ((T.sum(-1) * eq_fac).sum(-1) + ((T * (1 - T)).sum(-1) * neq_fac).sum(-1)) + self._check_Teig("noise_power", T) + return noise_const * ( + (T.sum(-1) * eq_fac).sum(-1) + + ((T * (1 - T)).sum(-1) * neq_fac).sum(-1) + ) # We need to manually weigh the k-points wkpt = self.wkpt T = self.transmission_eig(elec_from, elec_to, kavg=0) - self._check_Teig('noise_power', T) + self._check_Teig("noise_power", T) # Separate the calculation into two terms (see Ya.M. Blanter, M. Buttiker, Physics Reports 336 2000) - np = ((T.sum(-1) * eq_fac).sum(-1) + ((T * (1 - T)).sum(-1) * neq_fac).sum(-1)) * wkpt[0] + np = ( + (T.sum(-1) * eq_fac).sum(-1) + ((T * (1 - T)).sum(-1) * neq_fac).sum(-1) + ) * wkpt[0] for ik in range(1, self.nkpt): T = self.transmission_eig(elec_from, elec_to, kavg=ik) - self._check_Teig('noise_power', T) - np += ((T.sum(-1) * eq_fac).sum(-1) + ((T * (1 - T)).sum(-1) * neq_fac).sum(-1)) * wkpt[ik] + self._check_Teig("noise_power", T) + np += ( + (T.sum(-1) * eq_fac).sum(-1) + + ((T * (1 - T)).sum(-1) * neq_fac).sum(-1) + ) * wkpt[ik] else: T = self.transmission_eig(elec_from, elec_to, kavg=kavg) - self._check_Teig('noise_power', T) - np = (T.sum(-1) * eq_fac).sum(-1) + ((T * (1 - T)).sum(-1) * neq_fac).sum(-1) + self._check_Teig("noise_power", T) + np = (T.sum(-1) * eq_fac).sum(-1) + ((T * (1 - T)).sum(-1) * neq_fac).sum( + -1 + ) # Do final conversion return noise_const * np @@ -956,45 +1039,46 @@ def fano(self, elec_from=0, elec_to=1, kavg=True, zero_T=1e-6) -> ndarray: shot_noise : shot-noise term (zero temperature limit) noise_power : temperature dependent noise power """ + def dividend(T): - T[T <= zero_T] = 0. + T[T <= zero_T] = 0.0 return (T * (1 - T)).sum(-1) if isinstance(kavg, bool): if not kavg: # The user wants it k-resolved T = self.transmission_eig(elec_from, elec_to, kavg=False) - self._check_Teig('fano', T) + self._check_Teig("fano", T) fano = dividend(T) T = self.transmission(elec_from, elec_to) fano /= T[None, :] - fano[:, T <= 0.] = 0. + fano[:, T <= 0.0] = 0.0 return fano # We need to manually weigh the k-points wkpt = self.wkpt T = self.transmission_eig(elec_from, elec_to, kavg=0) - self._check_Teig('fano', T) + self._check_Teig("fano", T) fano = dividend(T) * wkpt[0] for ik in range(1, self.nkpt): T = self.transmission_eig(elec_from, elec_to, kavg=ik) - self._check_Teig('fano', T) + self._check_Teig("fano", T) fano += dividend(T) * wkpt[ik] else: T = self.transmission_eig(elec_from, elec_to, kavg=kavg) - self._check_Teig('fano', T) + self._check_Teig("fano", T) fano = dividend(T) # Divide by k-averaged transmission T = self.transmission(elec_from, elec_to) fano /= T - fano[T <= 0.] = 0. + fano[T <= 0.0] = 0.0 return fano def _sparse_data(self, name, elec, E, kavg=True) -> ndarray: - """ Internal routine for retrieving sparse data (orbital current, COOP) """ + """Internal routine for retrieving sparse data (orbital current, COOP)""" if elec is not None: elec = self._elec(elec) @@ -1002,15 +1086,15 @@ def _sparse_data(self, name, elec, E, kavg=True) -> ndarray: return self._value_E(name, elec, kavg, E) def _sparse_data_to_matrix(self, data, isc=None, orbitals=None) -> csr_matrix: - """ Internal routine for retrieving sparse data (orbital current, COOP) """ + """Internal routine for retrieving sparse data (orbital current, COOP)""" # Get the geometry for obtaining the sparsity pattern. geom = self.geometry # These are the row-pointers... - ncol = self._value('n_col') + ncol = self._value("n_col") # Get column indices - col = self._value('list_col') - 1 + col = self._value("list_col") - 1 # get subset orbitals if not orbitals is None: @@ -1026,9 +1110,7 @@ def _sparse_data_to_matrix(self, data, isc=None, orbitals=None) -> csr_matrix: # now figure out all places where we # have the corresponding values - all_col = np.logical_and( - np.in1d(row, all_col), - np.in1d(col, all_col)) + all_col = np.logical_and(np.in1d(row, all_col), np.in1d(col, all_col)) # reduce space col = col[all_col] @@ -1070,8 +1152,9 @@ def _sparse_data_to_matrix(self, data, isc=None, orbitals=None) -> csr_matrix: def ret_range(val, req): i = val // 2 if req is None: - return range(-i, i+1) + return range(-i, i + 1) return [req] + x = ret_range(nsc[0], isc[0]) y = ret_range(nsc[1], isc[1]) z = ret_range(nsc[2], isc[2]) @@ -1083,8 +1166,9 @@ def ret_range(val, req): all_col[i] = geom.sc_index([ix, iy, iz]) # Transfer all_col to the range - all_col = _a.array_arangei(all_col * geom.no, - n=_a.fulli(len(all_col), geom.no)) + all_col = _a.array_arangei( + all_col * geom.no, n=_a.fulli(len(all_col), geom.no) + ) # get both row and column indices row_nonzero = (ncol > 0).nonzero()[0] @@ -1108,13 +1192,15 @@ def ret_range(val, req): return csr_matrix((data, col, rptr), shape=mat_size) - def _sparse_matrix(self, name, elec, E, kavg=True, isc=None, orbitals=None) -> csr_matrix: - """ Internal routine for retrieving sparse matrices (orbital current, COOP) """ + def _sparse_matrix( + self, name, elec, E, kavg=True, isc=None, orbitals=None + ) -> csr_matrix: + """Internal routine for retrieving sparse matrices (orbital current, COOP)""" data = self._sparse_data(name, elec, E, kavg) return self._sparse_data_to_matrix(data, isc, orbitals) def sparse_orbital_to_atom(self, Dij, uc=False, sum_dup=True) -> csr_matrix: - """ Reduce a sparse matrix in orbital sparse to a sparse matrix in atomic indices + """Reduce a sparse matrix in orbital sparse to a sparse matrix in atomic indices This algorithm *may* keep the same non-zero entries, but will return a new csr_matrix with duplicate indices. @@ -1158,7 +1244,7 @@ def map_col(c): map_col = o2a # Lets do array notation for speeding up the computations - if not (issparse(Dij) and Dij.format == 'csr'): + if not (issparse(Dij) and Dij.format == "csr"): Dij = Dij.tocsr() # Check for the simple case of 1-orbital systems @@ -1202,7 +1288,7 @@ def map_col(c): @wrap_filterwarnings("ignore", category=SparseEfficiencyWarning) def sparse_atom_to_vector(self, Dab) -> ndarray: - """ Reduce an atomic sparse matrix to a vector contribution of each atom + """Reduce an atomic sparse matrix to a vector contribution of each atom Notes ----- @@ -1231,7 +1317,7 @@ def sparse_atom_to_vector(self, Dab) -> ndarray: Dia = getrow(ia) # Set diagonal to zero - Dia[0, ia] = 0. + Dia[0, ia] = 0.0 # Remove the diagonal (prohibits the calculation of the # norm of the zero vector, hence required) Dia.eliminate_zeros() @@ -1239,13 +1325,13 @@ def sparse_atom_to_vector(self, Dab) -> ndarray: # Now calculate the vector elements # Remark that the vector goes from ia -> ja rv = Rij(ia, Dia.indices) - rv = rv / np.sqrt((rv ** 2).sum(1))[:, None] + rv = rv / np.sqrt((rv**2).sum(1))[:, None] V[ia, :] = (Dia.data[:, None] * rv).sum(0) return V def sparse_orbital_to_vector(self, Dij, uc=False, sum_dup=True) -> ndarray: - """ Reduce an orbital sparse matrix to a vector contribution of each atom + """Reduce an orbital sparse matrix to a vector contribution of each atom Equivalent to calling `sparse_orbital_to_atom` and `sparse_atom_to_vector`. @@ -1332,15 +1418,16 @@ def sparse_orbital_to_scalar(self, Dij, activity=True) -> ndarray: return Da - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def orbital_transmission(self, E, elec=0, - kavg=True, - isc=None, - what: str="all", - orbitals=None) -> csr_matrix: - r""" Transmission at energy `E` between orbitals originating from `elec` + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def orbital_transmission( + self, E, elec=0, kavg=True, isc=None, what: str = "all", orbitals=None + ) -> csr_matrix: + r"""Transmission at energy `E` between orbitals originating from `elec` Each matrix element of the sparse matrix corresponds to the orbital indices of the underlying geometry (including buffer and electrode atoms). @@ -1418,15 +1505,17 @@ def orbital_transmission(self, E, elec=0, atom_transmission : energy resolved atomic transmission for each atom (scalar representation of bond-transmissions) atom_current : the atomic current for each atom (scalar representation of bond-currents) """ - J = self._sparse_matrix('J', elec, E, kavg, isc, orbitals) + J = self._sparse_matrix("J", elec, E, kavg, isc, orbitals) if what in ("+", "out"): J.data[J.data < 0] = 0 elif what in ("-", "in"): J.data[J.data > 0] = 0 elif what not in ("all", "both", "+-", "-+", "inout", "outin"): - raise ValueError(f"{self.__class__.__name__}.orbital_transmission 'what' keyword has " - "wrong value [all/both/+-, +/out,-/in] allowed.") + raise ValueError( + f"{self.__class__.__name__}.orbital_transmission 'what' keyword has " + "wrong value [all/both/+-, +/out,-/in] allowed." + ) # do not delete explicit 0's as the user can then know the sparse matrices # calculated. @@ -1434,12 +1523,22 @@ def orbital_transmission(self, E, elec=0, return J - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def orbital_current(self, elec=0, elec_other=1, kavg=True, isc=None, - what: str="all", orbitals=None) -> csr_matrix: - r""" Orbital current originating from `elec` as a sparse matrix + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def orbital_current( + self, + elec=0, + elec_other=1, + kavg=True, + isc=None, + what: str = "all", + orbitals=None, + ) -> csr_matrix: + r"""Orbital current originating from `elec` as a sparse matrix This is the bias window integrated quantity of `orbital_transmission`. As such it represents how the current is flowing at an applied bias from a given electrode. @@ -1524,19 +1623,29 @@ def func_all(data, A): }.get(what) if getdata is None: - raise ValueError(f"{self.__class__.__name__}.orbital_current 'what' keyword has " - "wrong value [all/both/+-/inout, +/out,-/in] allowed.") - - J = reduce(getdata, enumerate(integrator(self.E)), 0.) - - return self._sparse_data_to_matrix(J, isc, orbitals) * constant.q / constant.h("eV s") - - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def bond_transmission(self, E, elec=0, kavg=True, isc=None, - what: str="all", orbitals=None, uc=False) -> csr_matrix: - r""" Bond transmission between atoms at a specific energy + raise ValueError( + f"{self.__class__.__name__}.orbital_current 'what' keyword has " + "wrong value [all/both/+-/inout, +/out,-/in] allowed." + ) + + J = reduce(getdata, enumerate(integrator(self.E)), 0.0) + + return ( + self._sparse_data_to_matrix(J, isc, orbitals) + * constant.q + / constant.h("eV s") + ) + + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def bond_transmission( + self, E, elec=0, kavg=True, isc=None, what: str = "all", orbitals=None, uc=False + ) -> csr_matrix: + r"""Bond transmission between atoms at a specific energy Short hand function for calling `orbital_transmission` and `sparse_orbital_to_atom`. @@ -1591,17 +1700,29 @@ def bond_transmission(self, E, elec=0, kavg=True, isc=None, atom_transmission : energy resolved atomic transmission for each atom (scalar representation of bond-transmissions) atom_current : the atomic current for each atom (scalar representation of bond-currents) """ - Jij = self.orbital_transmission(E, elec, kavg=kavg, isc=isc, - what=what, orbitals=orbitals) + Jij = self.orbital_transmission( + E, elec, kavg=kavg, isc=isc, what=what, orbitals=orbitals + ) return self.sparse_orbital_to_atom(Jij, uc=uc) - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def bond_current(self, elec=0, elec_other=1, kavg=True, isc=None, - what: str="all", orbitals=None, uc=False) -> csr_matrix: - r""" Bond current between atoms (sum of orbital currents) + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def bond_current( + self, + elec=0, + elec_other=1, + kavg=True, + isc=None, + what: str = "all", + orbitals=None, + uc=False, + ) -> csr_matrix: + r"""Bond current between atoms (sum of orbital currents) Short hand function for calling `orbital_current` and `sparse_orbital_to_atom`. @@ -1664,17 +1785,22 @@ def bond_current(self, elec=0, elec_other=1, kavg=True, isc=None, atom_transmission : energy resolved atomic transmission for each atom (scalar representation of bond-transmissions) atom_current : the atomic current for each atom (scalar representation of bond-currents) """ - Jij = self.orbital_current(elec, elec_other, kavg=kavg, isc=isc, - what=what, orbitals=orbitals) + Jij = self.orbital_current( + elec, elec_other, kavg=kavg, isc=isc, what=what, orbitals=orbitals + ) return self.sparse_orbital_to_atom(Jij, uc=uc) - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def vector_transmission(self, E, elec=0, kavg=True, isc=None, - what="all", orbitals=None) -> ndarray: - r""" Vector for each atom being the sum of bond transmissions times the normalized bond vector between the atoms + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def vector_transmission( + self, E, elec=0, kavg=True, isc=None, what="all", orbitals=None + ) -> ndarray: + r"""Vector for each atom being the sum of bond transmissions times the normalized bond vector between the atoms The vector transmission is defined as: @@ -1725,8 +1851,9 @@ def vector_transmission(self, E, elec=0, kavg=True, isc=None, atom_transmission : energy resolved atomic transmission for each atom (scalar representation of bond-transmissions) atom_current : the atomic current for each atom (scalar representation of bond-currents) """ - Jab = self.bond_transmission(E, elec, kavg=kavg, isc=isc, - what=what, orbitals=orbitals) + Jab = self.bond_transmission( + E, elec, kavg=kavg, isc=isc, what=what, orbitals=orbitals + ) if what in ("all", "both", "+-", "-+", "inout", "outin"): # When we divide by two one can *always* compare the bulk @@ -1737,12 +1864,22 @@ def vector_transmission(self, E, elec=0, kavg=True, isc=None, return self.sparse_atom_to_vector(Jab) - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def vector_current(self, elec=0, elec_other=1, kavg=True, isc=None, - what: str="all", orbitals=None) -> ndarray: - r""" Vector for each atom being the sum of bond currents times the normalized bond vector between the atoms + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def vector_current( + self, + elec=0, + elec_other=1, + kavg=True, + isc=None, + what: str = "all", + orbitals=None, + ) -> ndarray: + r"""Vector for each atom being the sum of bond currents times the normalized bond vector between the atoms The vector current is defined as: @@ -1799,8 +1936,9 @@ def vector_current(self, elec=0, elec_other=1, kavg=True, isc=None, atom_transmission : energy resolved atomic transmission for each atom (scalar representation of bond-transmissions) atom_current : the atomic current for each atom (scalar representation of bond-currents) """ - Jab = self.bond_current(elec, elec_other, kavg=kavg, isc=isc, - what=what, orbitals=orbitals) + Jab = self.bond_current( + elec, elec_other, kavg=kavg, isc=isc, what=what, orbitals=orbitals + ) if what in ("all", "both", "+-", "-+", "inout", "outin"): # When we divide by two one can *always* compare the bulk @@ -1811,7 +1949,9 @@ def vector_current(self, elec=0, elec_other=1, kavg=True, isc=None, return self.sparse_atom_to_vector(Jab) - def atom_transmission(self, E, elec=0, activity=True, kavg=True, isc=None, orbitals=None) -> ndarray: + def atom_transmission( + self, E, elec=0, activity=True, kavg=True, isc=None, orbitals=None + ) -> ndarray: r""" Atomic transmission at energy `E` of atoms, a scalar quantity quantifying how much transmission flows through an atom The atomic transmission is a single number specifying a figure of the *magnitude* @@ -1875,12 +2015,15 @@ def atom_transmission(self, E, elec=0, activity=True, kavg=True, isc=None, orbit vector_current : an atomic field current for each atom (Cartesian representation of bond-currents) atom_current : the atomic current for each atom (scalar representation of bond-currents) """ - Jij = self.orbital_transmission(E, elec, kavg=kavg, isc=isc, - what="all", orbitals=orbitals) + Jij = self.orbital_transmission( + E, elec, kavg=kavg, isc=isc, what="all", orbitals=orbitals + ) return self.sparse_orbital_to_scalar(Jij, activity=activity) - def atom_current(self, elec=0, elec_other=1, activity=True, kavg=True, isc=None, orbitals=None) -> ndarray: + def atom_current( + self, elec=0, elec_other=1, activity=True, kavg=True, isc=None, orbitals=None + ) -> ndarray: r""" Atomic current of atoms, a scalar quantity quantifying how much currents flows through an atom The atomic current is a single number specifying a figure of the *magnitude* @@ -1950,13 +2093,16 @@ def atom_current(self, elec=0, elec_other=1, activity=True, kavg=True, isc=None, vector_current : an atomic field current for each atom (Cartesian representation of bond-currents) atom_transmission : energy resolved atomic transmission for each atom (scalar representation of bond-transmissions) """ - Jij = self.orbital_current(elec, elec_other, kavg=kavg, isc=isc, - what="all", orbitals=orbitals) + Jij = self.orbital_current( + elec, elec_other, kavg=kavg, isc=isc, what="all", orbitals=orbitals + ) return self.sparse_orbital_to_scalar(Jij, activity=activity) - def density_matrix(self, E, kavg=True, isc=None, orbitals=None, geometry=None) -> csr_matrix: - r""" Density matrix from the Green function at energy `E` (1/eV) + def density_matrix( + self, E, kavg=True, isc=None, orbitals=None, geometry=None + ) -> csr_matrix: + r"""Density matrix from the Green function at energy `E` (1/eV) The density matrix can be used to calculate the LDOS in real-space. @@ -2002,10 +2148,14 @@ def density_matrix(self, E, kavg=True, isc=None, orbitals=None, geometry=None) - DensityMatrix object containing the Geometry and the density matrix elements """ - return self.Adensity_matrix(None, E, kavg, isc, orbitals=orbitals, geometry=geometry) + return self.Adensity_matrix( + None, E, kavg, isc, orbitals=orbitals, geometry=geometry + ) - def Adensity_matrix(self, elec, E, kavg=True, isc=None, orbitals=None, geometry=None) -> csr_matrix: - r""" Spectral function density matrix at energy `E` (1/eV) + def Adensity_matrix( + self, elec, E, kavg=True, isc=None, orbitals=None, geometry=None + ) -> csr_matrix: + r"""Spectral function density matrix at energy `E` (1/eV) The density matrix can be used to calculate the LDOS in real-space. @@ -2053,14 +2203,16 @@ def Adensity_matrix(self, elec, E, kavg=True, isc=None, orbitals=None, geometry= DensityMatrix object containing the Geometry and the density matrix elements """ - dm = self._sparse_matrix('DM', elec, E, kavg, isc, orbitals) * eV2Ry + dm = self._sparse_matrix("DM", elec, E, kavg, isc, orbitals) * eV2Ry # Now create the density matrix object geom = self.geometry if geometry is None: DM = DensityMatrix.fromsp(geom, dm) else: if geom.no != geometry.no: - raise ValueError(f"{self.__class__.__name__}.Adensity_matrix requires input geometry to contain the correct number of orbitals. Please correct input!") + raise ValueError( + f"{self.__class__.__name__}.Adensity_matrix requires input geometry to contain the correct number of orbitals. Please correct input!" + ) DM = DensityMatrix.fromsp(geometry, dm) return DM @@ -2129,7 +2281,9 @@ def orbital_COOP(self, E, kavg=True, isc=None, orbitals=None) -> csr_matrix: """ return self.orbital_ACOOP(E, None, kavg=kavg, isc=isc, orbitals=orbitals) - def orbital_ACOOP(self, E, elec=0, kavg=True, isc=None, orbitals=None) -> csr_matrix: + def orbital_ACOOP( + self, E, elec=0, kavg=True, isc=None, orbitals=None + ) -> csr_matrix: r""" Orbital COOP analysis of the spectral function This will return a sparse matrix, see `~scipy.sparse.csr_matrix` for details. @@ -2193,10 +2347,10 @@ def orbital_ACOOP(self, E, elec=0, kavg=True, isc=None, orbitals=None) -> csr_ma orbital_ACOHP : orbital resolved COHP analysis of the spectral function atom_ACOHP : atomic COHP analysis of the spectral function """ - return self._sparse_matrix('COOP', elec, E, kavg, isc, orbitals) * eV2Ry + return self._sparse_matrix("COOP", elec, E, kavg, isc, orbitals) * eV2Ry def atom_COOP(self, E, kavg=True, isc=None, orbitals=None, uc=False) -> csr_matrix: - r""" Atomic COOP curve of the Green function + r"""Atomic COOP curve of the Green function The atomic COOP are a sum over all orbital COOP: @@ -2237,8 +2391,10 @@ def atom_COOP(self, E, kavg=True, isc=None, orbitals=None, uc=False) -> csr_matr """ return self.atom_ACOOP(E, None, kavg=kavg, isc=isc, orbitals=orbitals, uc=uc) - def atom_ACOOP(self, E, elec=0, kavg=True, isc=None, orbitals=None, uc=False) -> csr_matrix: - r""" Atomic COOP curve of the spectral function + def atom_ACOOP( + self, E, elec=0, kavg=True, isc=None, orbitals=None, uc=False + ) -> csr_matrix: + r"""Atomic COOP curve of the spectral function The atomic COOP are a sum over all orbital COOP: @@ -2285,7 +2441,7 @@ def atom_ACOOP(self, E, elec=0, kavg=True, isc=None, orbitals=None, uc=False) -> return self.sparse_orbital_to_atom(COOP, uc) def orbital_COHP(self, E, kavg=True, isc=None, orbitals=None) -> csr_matrix: - r""" Orbital resolved COHP analysis of the Green function + r"""Orbital resolved COHP analysis of the Green function This will return a sparse matrix, see ``scipy.sparse.csr_matrix`` for details. Each matrix element of the sparse matrix corresponds to the COHP of the @@ -2331,8 +2487,10 @@ def orbital_COHP(self, E, kavg=True, isc=None, orbitals=None) -> csr_matrix: """ return self.orbital_ACOHP(E, None, kavg=kavg, isc=isc, orbitals=orbitals) - def orbital_ACOHP(self, E, elec=0, kavg=True, isc=None, orbitals=None) -> csr_matrix: - r""" Orbital resolved COHP analysis of the spectral function + def orbital_ACOHP( + self, E, elec=0, kavg=True, isc=None, orbitals=None + ) -> csr_matrix: + r"""Orbital resolved COHP analysis of the spectral function This will return a sparse matrix, see ``scipy.sparse.csr_matrix`` for details. Each matrix element of the sparse matrix corresponds to the COHP of the @@ -2373,10 +2531,10 @@ def orbital_ACOHP(self, E, elec=0, kavg=True, isc=None, orbitals=None) -> csr_ma atom_COHP : atomic COHP analysis of the Green function atom_ACOHP : atomic COHP analysis of the spectral function """ - return self._sparse_matrix('COHP', elec, E, kavg, isc, orbitals) + return self._sparse_matrix("COHP", elec, E, kavg, isc, orbitals) def atom_COHP(self, E, kavg=True, isc=None, orbitals=None, uc=False) -> csr_matrix: - r""" Atomic COHP curve of the Green function + r"""Atomic COHP curve of the Green function The atomic COHP are a sum over all orbital COHP: @@ -2417,8 +2575,10 @@ def atom_COHP(self, E, kavg=True, isc=None, orbitals=None, uc=False) -> csr_matr """ return self.atom_ACOHP(E, None, kavg=kavg, isc=isc, orbitals=orbitals, uc=uc) - def atom_ACOHP(self, E, elec=0, kavg=True, isc=None, orbitals=None, uc=False) -> csr_matrix: - r""" Atomic COHP curve of the spectral function + def atom_ACOHP( + self, E, elec=0, kavg=True, isc=None, orbitals=None, uc=False + ) -> csr_matrix: + r"""Atomic COHP curve of the spectral function Parameters ---------- @@ -2458,7 +2618,7 @@ def atom_ACOHP(self, E, elec=0, kavg=True, isc=None, orbitals=None, uc=False) -> return self.sparse_orbital_to_atom(COHP, uc) def read_data(self, *args, **kwargs): - """ Read specific type of data. + """Read specific type of data. This is a generic routine for reading different parts of the data-file. @@ -2477,13 +2637,16 @@ def read_data(self, *args, **kwargs): """ val = [] for kw in kwargs: - if kw in ("geom", "geometry"): if kwargs[kw]: val.append(self.geometry) - elif kw in ("atom_current", "atom_transmission", - "vector_current", "vector_transmission"): + elif kw in ( + "atom_current", + "atom_transmission", + "vector_current", + "vector_transmission", + ): if kwargs[kw]: # TODO we need some way of handling arguments. val.append(getattr(self, kw)(*args)) @@ -2495,7 +2658,7 @@ def read_data(self, *args, **kwargs): return val def info(self, elec=None): - """ Information about the calculated quantities available for extracting in this file + """Information about the calculated quantities available for extracting in this file Parameters ---------- @@ -2507,13 +2670,14 @@ def info(self, elec=None): # Create a StringIO object to retain the information out = StringIO() + # Create wrapper function def prnt(*args, **kwargs): - option = kwargs.pop('option', None) + option = kwargs.pop("option", None) if option is None: print(*args, file=out) else: - print('{:60s}[{}]'.format(' '.join(args), ', '.join(option)), file=out) + print("{:60s}[{}]".format(" ".join(args), ", ".join(option)), file=out) def truefalse(bol, string, fdf=None): if bol: @@ -2534,14 +2698,18 @@ def truefalse(bol, string, fdf=None): nA = len(np.unique(kpt[:, 0])) nB = len(np.unique(kpt[:, 1])) nC = len(np.unique(kpt[:, 2])) - prnt((" - number of kpoints: {} <- " - "[ A = {} , B = {} , C = {} ] (time-reversal unknown)").format(self.nk, nA, nB, nC)) + prnt( + ( + " - number of kpoints: {} <- " + "[ A = {} , B = {} , C = {} ] (time-reversal unknown)" + ).format(self.nk, nA, nB, nC) + ) prnt(" - energy range:") E = self.E Em, EM = np.amin(E), np.amax(E) dE = np.diff(E) - dEm, dEM = np.amin(dE) * 1000, np.amax(dE) * 1000 # convert to meV - if (dEM - dEm) < 1e-3: # 0.001 meV + dEm, dEM = np.amin(dE) * 1000, np.amax(dE) * 1000 # convert to meV + if (dEM - dEm) < 1e-3: # 0.001 meV prnt(f" {Em:.5f} -- {EM:.5f} eV [{dEm:.3f} meV]") else: prnt(f" {Em:.5f} -- {EM:.5f} eV [{dEm:.3f} -- {dEM:.3f} meV]") @@ -2549,10 +2717,12 @@ def truefalse(bol, string, fdf=None): prnt(" - atoms with DOS (1-based):") prnt(" " + list2str(self.a_dev + 1)) prnt(" - number of BTD blocks: {}".format(self.n_btd())) - truefalse('DOS' in self.variables, "DOS Green function", ['TBT.DOS.Gf']) - truefalse('DM' in self.variables, "Density matrix Green function", ['TBT.DM.Gf']) - truefalse('COOP' in self.variables, "COOP Green function", ['TBT.COOP.Gf']) - truefalse('COHP' in self.variables, "COHP Green function", ['TBT.COHP.Gf']) + truefalse("DOS" in self.variables, "DOS Green function", ["TBT.DOS.Gf"]) + truefalse( + "DM" in self.variables, "Density matrix Green function", ["TBT.DM.Gf"] + ) + truefalse("COOP" in self.variables, "COOP Green function", ["TBT.COOP.Gf"]) + truefalse("COHP" in self.variables, "COHP Green function", ["TBT.COHP.Gf"]) if elec is None: elecs = self.elecs else: @@ -2571,34 +2741,60 @@ def truefalse(bol, string, fdf=None): try: n_btd = self.n_btd(elec) except Exception: - n_btd = 'unknown' + n_btd = "unknown" prnt() prnt(f"Electrode: {elec}") prnt(f" - number of BTD blocks: {n_btd}") prnt(" - Bloch: [{}, {}, {}]".format(*bloch)) gelec = self.groups[elec] - if 'TBT' in self._trans_type: - prnt(" - chemical potential: {:.4f} eV".format(self.chemical_potential(elec))) - prnt(" - electron temperature: {:.2f} K".format(self.electron_temperature(elec))) + if "TBT" in self._trans_type: + prnt( + " - chemical potential: {:.4f} eV".format( + self.chemical_potential(elec) + ) + ) + prnt( + " - electron temperature: {:.2f} K".format( + self.electron_temperature(elec) + ) + ) else: - prnt(" - phonon temperature: {:.4f} K".format(self.phonon_temperature(elec))) + prnt( + " - phonon temperature: {:.4f} K".format( + self.phonon_temperature(elec) + ) + ) prnt(" - imaginary part (eta): {:.4f} meV".format(self.eta(elec) * 1e3)) - truefalse('DOS' in gelec.variables, "DOS bulk", ['TBT.DOS.Elecs']) - truefalse('ADOS' in gelec.variables, "DOS spectral", ['TBT.DOS.A']) - truefalse('J' in gelec.variables, "orbital-transmission", ['TBT.Current.Orb']) - truefalse('DM' in gelec.variables, "Density matrix spectral", ['TBT.DM.A']) - truefalse('COOP' in gelec.variables, "COOP spectral", ['TBT.COOP.A']) - truefalse('COHP' in gelec.variables, "COHP spectral", ['TBT.COHP.A']) - truefalse('T' in gelec.variables, "transmission bulk", ['TBT.T.Bulk']) - truefalse(f"{elec}.T" in gelec.variables, "transmission out", ['TBT.T.Out']) - truefalse(f"{elec}.C" in gelec.variables, "transmission out correction", ['TBT.T.Out']) - truefalse(f"{elec}.C.Eig" in gelec.variables, "transmission out correction (eigen)", ['TBT.T.Out', 'TBT.T.Eig']) + truefalse("DOS" in gelec.variables, "DOS bulk", ["TBT.DOS.Elecs"]) + truefalse("ADOS" in gelec.variables, "DOS spectral", ["TBT.DOS.A"]) + truefalse( + "J" in gelec.variables, "orbital-transmission", ["TBT.Current.Orb"] + ) + truefalse("DM" in gelec.variables, "Density matrix spectral", ["TBT.DM.A"]) + truefalse("COOP" in gelec.variables, "COOP spectral", ["TBT.COOP.A"]) + truefalse("COHP" in gelec.variables, "COHP spectral", ["TBT.COHP.A"]) + truefalse("T" in gelec.variables, "transmission bulk", ["TBT.T.Bulk"]) + truefalse(f"{elec}.T" in gelec.variables, "transmission out", ["TBT.T.Out"]) + truefalse( + f"{elec}.C" in gelec.variables, + "transmission out correction", + ["TBT.T.Out"], + ) + truefalse( + f"{elec}.C.Eig" in gelec.variables, + "transmission out correction (eigen)", + ["TBT.T.Out", "TBT.T.Eig"], + ) for elec2 in self.elecs: # Skip it self, checked above in .T and .C if elec2 == elec: continue truefalse(f"{elec2}.T" in gelec.variables, f"transmission -> {elec2}") - truefalse(f"{elec2}.T.Eig" in gelec.variables, f"transmission (eigen) -> {elec2}", ['TBT.T.Eig']) + truefalse( + f"{elec2}.T.Eig" in gelec.variables, + f"transmission (eigen) -> {elec2}", + ["TBT.T.Eig"], + ) s = out.getvalue() out.close() @@ -2606,35 +2802,40 @@ def truefalse(bol, string, fdf=None): @default_ArgumentParser(description="Extract data from a TBT.nc file") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" # We limit the import to occur here import argparse - namespace = default_namespace(_tbt=self, - _geometry=self.geometry, - _data=[], _data_description=[], _data_header=[], - _norm='none', - _Ovalue='', _Orng=None, _Erng=None, - _krng=True) + namespace = default_namespace( + _tbt=self, + _geometry=self.geometry, + _data=[], + _data_description=[], + _data_header=[], + _norm="none", + _Ovalue="", + _Orng=None, + _Erng=None, + _krng=True, + ) def ensure_E(func): - """ This decorater ensures that E is the first element in the _data container """ + """This decorater ensures that E is the first element in the _data container""" def assign_E(self, *args, **kwargs): ns = args[1] if len(ns._data) == 0: # We immediately extract the energies ns._data.append(ns._tbt.E[ns._Erng].flatten()) - ns._data_header.append('Energy[eV]') + ns._data_header.append("Energy[eV]") return func(self, *args, **kwargs) + return assign_E # Correct the geometry species information class GeometryAction(argparse.Action): - def __call__(self, parser, ns, value, option_string=None): - old_g = ns._geometry.copy() # Now read the file to read the geometry from @@ -2648,13 +2849,18 @@ def __call__(self, parser, ns, value, option_string=None): g._atoms = Atoms(atoms) ns._geometry = g - p.add_argument('--geometry', '-G', - action=GeometryAction, - help=('Update the geometry of the output file, this enables one to set the species correctly,' - ' note this only affects output-files where species are important')) - class ERange(argparse.Action): + p.add_argument( + "--geometry", + "-G", + action=GeometryAction, + help=( + "Update the geometry of the output file, this enables one to set the species correctly," + " note this only affects output-files where species are important" + ), + ) + class ERange(argparse.Action): def __call__(self, parser, ns, value, option_string=None): E = ns._tbt.E Emap = strmap(float, value, E.min(), E.max()) @@ -2665,70 +2871,87 @@ def __call__(self, parser, ns, value, option_string=None): ns._Erng = None return elif begin is None: - E.append(range(ns._tbt.Eindex(end)+1)) + E.append(range(ns._tbt.Eindex(end) + 1)) elif end is None: E.append(range(ns._tbt.Eindex(begin), len(ns._tbt.E))) else: - E.append(range(ns._tbt.Eindex(begin), ns._tbt.Eindex(end)+1)) + E.append(range(ns._tbt.Eindex(begin), ns._tbt.Eindex(end) + 1)) # Issuing unique also sorts the entries ns._Erng = np.unique(_a.arrayi(E).flatten()) - p.add_argument('--energy', '-E', action=ERange, - help="""Denote the sub-section of energies that are extracted: "-1:0,1:2" [eV] - This flag takes effect on all energy-resolved quantities and is reset whenever --plot or --out is called""") + p.add_argument( + "--energy", + "-E", + action=ERange, + help="""Denote the sub-section of energies that are extracted: "-1:0,1:2" [eV] + + This flag takes effect on all energy-resolved quantities and is reset whenever --plot or --out is called""", + ) # k-range class kRange(argparse.Action): - @collect_action def __call__(self, parser, ns, value, option_string=None): try: ns._krng = int(value) except Exception: # Parse it as an array - if ',' in value: - k = map(float, value.split(',')) + if "," in value: + k = map(float, value.split(",")) else: k = map(float, value.split()) k = list(k) if len(k) != 3: - raise ValueError("Argument --kpoint *must* be an integer or 3 values to find the corresponding k-index") + raise ValueError( + "Argument --kpoint *must* be an integer or 3 values to find the corresponding k-index" + ) ns._krng = ns._tbt.kindex(k) # Add a description on which k-point this is k = ns._tbt.k[ns._krng] - ns._data_description.append('Data is extracted at k-point: [{} {} {}]'.format(k[0], k[1], k[2])) + ns._data_description.append( + "Data is extracted at k-point: [{} {} {}]".format(k[0], k[1], k[2]) + ) if not self._k_avg: - p.add_argument('--kpoint', '-k', action=kRange, - help="""Denote a specific k-index or comma/white-space separated k-point that is extracted, default to k-averaged quantity. + p.add_argument( + "--kpoint", + "-k", + action=kRange, + help="""Denote a specific k-index or comma/white-space separated k-point that is extracted, default to k-averaged quantity. For specific k-points the k weight will not be used. - This flag takes effect on all k-resolved quantities and is reset whenever --plot or --out is called""") + This flag takes effect on all k-resolved quantities and is reset whenever --plot or --out is called""", + ) # The normalization method class NormAction(argparse.Action): - @collect_action def __call__(self, parser, ns, value, option_string=None): ns._norm = value - p.add_argument('--norm', '-N', action=NormAction, default='atom', - choices=['none', 'atom', 'orbital', 'all'], - help="""Specify the normalization method; "none") no normalization, "atom") total orbitals in selected atoms, + + p.add_argument( + "--norm", + "-N", + action=NormAction, + default="atom", + choices=["none", "atom", "orbital", "all"], + help="""Specify the normalization method; "none") no normalization, "atom") total orbitals in selected atoms, "orbital") selected orbitals or "all") total orbitals in the device region. - This flag only takes effect on --dos and --ados and is reset whenever --plot or --out is called""") + This flag only takes effect on --dos and --ados and is reset whenever --plot or --out is called""", + ) # Try and add the atomic specification class AtomRange(argparse.Action): - @collect_action def __call__(self, parser, ns, value, option_string=None): - value = ",".join(# ensure only single commas (no space between them) - "".join(# ensure no empty whitespaces - ",".join(# join different lines with a comma - value.splitlines()) - .split()) - .split(",")) + value = ",".join( # ensure only single commas (no space between them) + "".join( # ensure no empty whitespaces + ",".join( # join different lines with a comma + value.splitlines() + ).split() + ).split(",") + ) # Immediately convert to proper indices geom = ns._geometry @@ -2741,15 +2964,19 @@ def __call__(self, parser, ns, value, option_string=None): # * will "only" fail if files are named accordingly, else # it will be passed as-is. # { [ * - for sep in ('b', 'c'): + for sep in ("b", "c"): try: - ranges = lstranges(strmap(int, value, a_dev.min(), a_dev.max(), sep)) + ranges = lstranges( + strmap(int, value, a_dev.min(), a_dev.max(), sep) + ) break except Exception: pass else: # only if break was not encountered - raise ValueError(f"Could not parse the atomic/orbital ranges: {value}") + raise ValueError( + f"Could not parse the atomic/orbital ranges: {value}" + ) # we have only a subset of the orbitals orbs = [] @@ -2780,11 +3007,13 @@ def __call__(self, parser, ns, value, option_string=None): orbs.append(ob) if len(orbs) == 0: - print('Device atoms:') - print(' ', list2str(a_dev)) - print('Input atoms:') - print(' ', value) - raise ValueError('Atomic/Orbital requests are not fully included in the device region.') + print("Device atoms:") + print(" ", list2str(a_dev)) + print("Input atoms:") + print(" ", value) + raise ValueError( + "Atomic/Orbital requests are not fully included in the device region." + ) # Add one to make the c-index equivalent to the f-index orbs = np.concatenate(orbs).flatten() @@ -2793,78 +3022,106 @@ def __call__(self, parser, ns, value, option_string=None): if len(orbs) != len(ns._tbt.o2p(orbs)): # This should in principle never be called because of the # checks above. - print('Device atoms:') - print(' ', list2str(a_dev)) - print('Input atoms:') - print(' ', value) - raise ValueError('Atomic/Orbital requests are not fully included in the device region.') + print("Device atoms:") + print(" ", list2str(a_dev)) + print("Input atoms:") + print(" ", value) + raise ValueError( + "Atomic/Orbital requests are not fully included in the device region." + ) ns._Ovalue = value ns._Orng = orbs - p.add_argument('--atom', '-a', type=str, action=AtomRange, - help="""Limit orbital resolved quantities to a sub-set of atoms/orbitals: "1-2[3,4]" will yield the 1st and 2nd atom and their 3rd and fourth orbital. Multiple comma-separated specifications are allowed. Note that some shells does not allow [] as text-input (due to expansion), {, [ or * are allowed orbital delimiters. + p.add_argument( + "--atom", + "-a", + type=str, + action=AtomRange, + help="""Limit orbital resolved quantities to a sub-set of atoms/orbitals: "1-2[3,4]" will yield the 1st and 2nd atom and their 3rd and fourth orbital. Multiple comma-separated specifications are allowed. Note that some shells does not allow [] as text-input (due to expansion), {, [ or * are allowed orbital delimiters. - This flag takes effect on all atom/orbital resolved quantities (except BDOS, transmission_bulk) and is reset whenever --plot or --out is called""") + This flag takes effect on all atom/orbital resolved quantities (except BDOS, transmission_bulk) and is reset whenever --plot or --out is called""", + ) class DataT(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, values, option_string=None): e1 = ns._tbt._elec(values[0]) if e1 not in ns._tbt.elecs: - raise ValueError(f"Electrode: '{e1}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e1}' cannot be found in the specified file." + ) e2 = ns._tbt._elec(values[1]) if e2 not in ns._tbt.elecs: - if e2.strip() == '.': + if e2.strip() == ".": for e2 in ns._tbt.elecs: if e2 != e1: - try: # catches if T isn't calculated + try: # catches if T isn't calculated self(parser, ns, [e1, e2], option_string) except Exception: pass return - raise ValueError(f"Electrode: '{e2}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e2}' cannot be found in the specified file." + ) # Grab the information data = ns._tbt.transmission(e1, e2, kavg=ns._krng)[ns._Erng] data.shape = (-1,) ns._data.append(data) - ns._data_header.append(f'T:{e1}-{e2}') - ns._data_description.append('Column {} is transmission from {} to {}'.format(len(ns._data), e1, e2)) - p.add_argument('-T', '--transmission', nargs=2, metavar=('ELEC1', 'ELEC2'), - action=DataT, - help='Store transmission between two electrodes.') + ns._data_header.append(f"T:{e1}-{e2}") + ns._data_description.append( + "Column {} is transmission from {} to {}".format( + len(ns._data), e1, e2 + ) + ) + + p.add_argument( + "-T", + "--transmission", + nargs=2, + metavar=("ELEC1", "ELEC2"), + action=DataT, + help="Store transmission between two electrodes.", + ) class DataBT(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, value, option_string=None): e = ns._tbt._elec(value[0]) if e not in ns._tbt.elecs: - if e.strip() == '.': + if e.strip() == ".": for e in ns._tbt.elecs: - try: # catches if B isn't calculated + try: # catches if B isn't calculated self(parser, ns, [e], option_string) except Exception: pass return - raise ValueError(f"Electrode: '{e}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e}' cannot be found in the specified file." + ) # Grab the information data = ns._tbt.transmission_bulk(e, kavg=ns._krng)[ns._Erng] data.shape = (-1,) ns._data.append(data) - ns._data_header.append(f'BT:{e}') - ns._data_description.append('Column {} is bulk-transmission'.format(len(ns._data))) - p.add_argument('-BT', '--transmission-bulk', nargs=1, metavar='ELEC', - action=DataBT, - help='Store bulk transmission of an electrode.') + ns._data_header.append(f"BT:{e}") + ns._data_description.append( + "Column {} is bulk-transmission".format(len(ns._data)) + ) + + p.add_argument( + "-BT", + "--transmission-bulk", + nargs=1, + metavar="ELEC", + action=DataBT, + help="Store bulk transmission of an electrode.", + ) class DataDOS(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, value, option_string=None): @@ -2872,70 +3129,107 @@ def __call__(self, parser, ns, value, option_string=None): # we are storing the spectral DOS e = ns._tbt._elec(value) if e not in ns._tbt.elecs: - raise ValueError(f"Electrode: '{e}' cannot be found in the specified file.") - data = ns._tbt.ADOS(e, kavg=ns._krng, orbitals=ns._Orng, norm=ns._norm) - ns._data_header.append(f'ADOS[1/eV]:{e}') + raise ValueError( + f"Electrode: '{e}' cannot be found in the specified file." + ) + data = ns._tbt.ADOS( + e, kavg=ns._krng, orbitals=ns._Orng, norm=ns._norm + ) + ns._data_header.append(f"ADOS[1/eV]:{e}") else: data = ns._tbt.DOS(kavg=ns._krng, orbitals=ns._Orng, norm=ns._norm) - ns._data_header.append('DOS[1/eV]') + ns._data_header.append("DOS[1/eV]") NORM = int(ns._tbt.norm(orbitals=ns._Orng, norm=ns._norm)) # The flatten is because when ns._Erng is None, then a new # dimension (of size 1) is created ns._data.append(data[ns._Erng].flatten()) if ns._Orng is None: - ns._data_description.append('Column {} is sum of all device atoms+orbitals with normalization 1/{}'.format(len(ns._data), NORM)) + ns._data_description.append( + "Column {} is sum of all device atoms+orbitals with normalization 1/{}".format( + len(ns._data), NORM + ) + ) else: - ns._data_description.append('Column {} is atoms[orbs] {} with normalization 1/{}'.format(len(ns._data), ns._Ovalue, NORM)) - - p.add_argument('--dos', '-D', nargs='?', metavar='ELEC', - action=DataDOS, default=None, - help="""Store DOS. If no electrode is specified, it is Green function, else it is the spectral function.""") - p.add_argument('--ados', '-AD', metavar='ELEC', - action=DataDOS, default=None, - help="""Store spectral DOS, same as --dos but requires an electrode-argument.""") + ns._data_description.append( + "Column {} is atoms[orbs] {} with normalization 1/{}".format( + len(ns._data), ns._Ovalue, NORM + ) + ) + + p.add_argument( + "--dos", + "-D", + nargs="?", + metavar="ELEC", + action=DataDOS, + default=None, + help="""Store DOS. If no electrode is specified, it is Green function, else it is the spectral function.""", + ) + p.add_argument( + "--ados", + "-AD", + metavar="ELEC", + action=DataDOS, + default=None, + help="""Store spectral DOS, same as --dos but requires an electrode-argument.""", + ) class DataDOSBulk(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, value, option_string=None): - # we are storing the Bulk DOS e = ns._tbt._elec(value[0]) if e not in ns._tbt.elecs: - raise ValueError(f"Electrode: '{e}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e}' cannot be found in the specified file." + ) # Grab the information data = ns._tbt.BDOS(e, kavg=ns._krng, sum=False) - ns._data_header.append(f'BDOS[1/eV]:{e}') + ns._data_header.append(f"BDOS[1/eV]:{e}") # Select the energies, even if _Erng is None, this will work! no = data.shape[-1] data = np.mean(data[ns._Erng, ...], axis=-1).flatten() ns._data.append(data) - ns._data_description.append('Column {} is sum of all electrode[{}] atoms+orbitals with normalization 1/{}'.format(len(ns._data), e, no)) - p.add_argument('--bulk-dos', '-BD', nargs=1, metavar='ELEC', - action=DataDOSBulk, default=None, - help="""Store bulk DOS of an electrode.""") + ns._data_description.append( + "Column {} is sum of all electrode[{}] atoms+orbitals with normalization 1/{}".format( + len(ns._data), e, no + ) + ) + + p.add_argument( + "--bulk-dos", + "-BD", + nargs=1, + metavar="ELEC", + action=DataDOSBulk, + default=None, + help="""Store bulk DOS of an electrode.""", + ) class DataTEig(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, values, option_string=None): e1 = ns._tbt._elec(values[0]) if e1 not in ns._tbt.elecs: - raise ValueError(f"Electrode: '{e1}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e1}' cannot be found in the specified file." + ) e2 = ns._tbt._elec(values[1]) if e2 not in ns._tbt.elecs: - if e2.strip() == '.': + if e2.strip() == ".": for e2 in ns._tbt.elecs: if e1 != e2: - try: # catches if T-eig isn't calculated + try: # catches if T-eig isn't calculated self(parser, ns, [e1, e2], option_string) except Exception: pass return - raise ValueError(f"Electrode: '{e2}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e2}' cannot be found in the specified file." + ) # Grab the information data = ns._tbt.transmission_eig(e1, e2, kavg=ns._krng) @@ -2943,100 +3237,136 @@ def __call__(self, parser, ns, values, option_string=None): neig = data.shape[-1] for eig in range(neig): ns._data.append(data[ns._Erng, ..., eig].flatten()) - ns._data_header.append('Teig({}):{}-{}'.format(eig+1, e1, e2)) - ns._data_description.append('Column {} is transmission eigenvalues from electrode {} to {}'.format(len(ns._data), e1, e2)) - p.add_argument('--transmission-eig', '-Teig', nargs=2, metavar=('ELEC1', 'ELEC2'), - action=DataTEig, - help='Store transmission eigenvalues between two electrodes.') + ns._data_header.append("Teig({}):{}-{}".format(eig + 1, e1, e2)) + ns._data_description.append( + "Column {} is transmission eigenvalues from electrode {} to {}".format( + len(ns._data), e1, e2 + ) + ) + + p.add_argument( + "--transmission-eig", + "-Teig", + nargs=2, + metavar=("ELEC1", "ELEC2"), + action=DataTEig, + help="Store transmission eigenvalues between two electrodes.", + ) class DataFano(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, values, option_string=None): e1 = ns._tbt._elec(values[0]) if e1 not in ns._tbt.elecs: - raise ValueError(f"Electrode: '{e1}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e1}' cannot be found in the specified file." + ) e2 = ns._tbt._elec(values[1]) if e2 not in ns._tbt.elecs: - if e2.strip() == '.': + if e2.strip() == ".": for e2 in ns._tbt.elecs: if e2 != e1: - try: # catches if T isn't calculated + try: # catches if T isn't calculated self(parser, ns, [e1, e2], option_string) except Exception: pass return - raise ValueError(f"Electrode: '{e2}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e2}' cannot be found in the specified file." + ) # Grab the information data = ns._tbt.fano(e1, e2, kavg=ns._krng)[ns._Erng] data.shape = (-1,) ns._data.append(data) - ns._data_header.append(f'Fano:{e1}-{e2}') - ns._data_description.append(f'Column {len(ns._data)} is fano-factor from {e1} to {e2}') - p.add_argument('--fano', nargs=2, metavar=('ELEC1', 'ELEC2'), - action=DataFano, - help='Store fano-factor between two electrodes.') + ns._data_header.append(f"Fano:{e1}-{e2}") + ns._data_description.append( + f"Column {len(ns._data)} is fano-factor from {e1} to {e2}" + ) + + p.add_argument( + "--fano", + nargs=2, + metavar=("ELEC1", "ELEC2"), + action=DataFano, + help="Store fano-factor between two electrodes.", + ) class DataShot(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, values, option_string=None): - classical = values[0].lower() in ('classical', 'c') + classical = values[0].lower() in ("classical", "c") e1 = ns._tbt._elec(values[1]) if e1 not in ns._tbt.elecs: - raise ValueError(f"Electrode: '{e1}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e1}' cannot be found in the specified file." + ) e2 = ns._tbt._elec(values[2]) if e2 not in ns._tbt.elecs: - if e2.strip() == '.': + if e2.strip() == ".": for e2 in ns._tbt.elecs: if e2 != e1: - try: # catches if T isn't calculated + try: # catches if T isn't calculated self(parser, ns, [values[0], e1, e2], option_string) except Exception: pass return - raise ValueError(f"Electrode: '{e2}' cannot be found in the specified file.") + raise ValueError( + f"Electrode: '{e2}' cannot be found in the specified file." + ) # Grab the information - data = ns._tbt.shot_noise(e1, e2, classical=classical, - kavg=ns._krng)[ns._Erng] + data = ns._tbt.shot_noise(e1, e2, classical=classical, kavg=ns._krng)[ + ns._Erng + ] data.shape = (-1,) ns._data.append(data) - ns._data_header.append(f'Shot:{e1}-{e2}') + ns._data_header.append(f"Shot:{e1}-{e2}") if classical: - method = 'classical' + method = "classical" else: - method = 'non-classical' - ns._data_description.append(f'Column {len(ns._data)} is {method} shot-noise from {e1} to {e2}') - p.add_argument('--shot-noise', nargs=3, metavar=('METHOD', 'ELEC1', 'ELEC2'), - action=DataShot, - help='Store shot-noise between two electrodes.') + method = "non-classical" + ns._data_description.append( + f"Column {len(ns._data)} is {method} shot-noise from {e1} to {e2}" + ) + + p.add_argument( + "--shot-noise", + nargs=3, + metavar=("METHOD", "ELEC1", "ELEC2"), + action=DataShot, + help="Store shot-noise between two electrodes.", + ) class Info(argparse.Action): - """ Action to print information contained in the TBT.nc file, helpful before performing actions """ + """Action to print information contained in the TBT.nc file, helpful before performing actions""" def __call__(self, parser, ns, value, option_string=None): # First short-hand the file print(ns._tbt.info(value)) - p.add_argument('--info', '-i', action=Info, nargs='?', metavar='ELEC', - help='Print out what information is contained in the TBT.nc file, optionally only for one of the electrodes.') + p.add_argument( + "--info", + "-i", + action=Info, + nargs="?", + metavar="ELEC", + help="Print out what information is contained in the TBT.nc file, optionally only for one of the electrodes.", + ) class Out(argparse.Action): @run_actions def __call__(self, parser, ns, value, option_string=None): - out = value[0] try: # We figure out if the user wants to write # to a geometry - obj = get_sile(out, mode='w') - if hasattr(obj, 'write_geometry'): + obj = get_sile(out, mode="w") + if hasattr(obj, "write_geometry"): with obj as fh: fh.write_geometry(ns._geometry) return @@ -3046,56 +3376,73 @@ def __call__(self, parser, ns, value, option_string=None): if len(ns._data) == 0: # do nothing if data has not been collected - print("No data has been collected in the arguments, nothing will be written, have you forgotten arguments?") + print( + "No data has been collected in the arguments, nothing will be written, have you forgotten arguments?" + ) return from sisl.io import tableSile - tableSile(out, mode='w').write(*ns._data, - comment=ns._data_description, - header=ns._data_header) + + tableSile(out, mode="w").write( + *ns._data, comment=ns._data_description, header=ns._data_header + ) # Clean all data ns._data_description = [] ns._data_header = [] ns._data = [] # These are expert options - ns._norm = 'none' - ns._Ovalue = '' + ns._norm = "none" + ns._Ovalue = "" ns._Orng = None ns._Erng = None ns._krng = True - p.add_argument('--out', '-o', nargs=1, action=Out, - help='Store currently collected information (at its current invocation) to the out file.') - class AVOut(argparse.Action): + p.add_argument( + "--out", + "-o", + nargs=1, + action=Out, + help="Store currently collected information (at its current invocation) to the out file.", + ) + class AVOut(argparse.Action): def __call__(self, parser, ns, value, option_string=None): if value is None: ns._tbt.write_tbtav() else: ns._tbt.write_tbtav(value) - p.add_argument('--tbt-av', action=AVOut, nargs='?', default=None, - help='Create "{}" with the k-averaged quantities of this file.'.format(str(self.file).replace('TBT.nc', 'TBT.AV.nc'))) - class Plot(argparse.Action): + p.add_argument( + "--tbt-av", + action=AVOut, + nargs="?", + default=None, + help='Create "{}" with the k-averaged quantities of this file.'.format( + str(self.file).replace("TBT.nc", "TBT.AV.nc") + ), + ) + class Plot(argparse.Action): @run_actions def __call__(self, parser, ns, value, option_string=None): - if len(ns._data) == 0: # do nothing if data has not been collected - print("No data has been collected in the arguments, nothing will be plotted, have you forgotten arguments?") + print( + "No data has been collected in the arguments, nothing will be plotted, have you forgotten arguments?" + ) return from matplotlib import pyplot as plt + plt.figure() def _get_header(header): - val_info = header.split(':') + val_info = header.split(":") if len(val_info) == 1: # We smiply have the data - return val_info[0].split('[')[0] + return val_info[0].split("[")[0] # We have a value *and* the electrode - return '{}:{}'.format(val_info[0].split('[')[0], val_info[1]) + return "{}:{}".format(val_info[0].split("[")[0], val_info[1]) is_DOS = True is_T = True @@ -3103,26 +3450,28 @@ def _get_header(header): is_SHOT = True is_FANO = True for i in range(1, len(ns._data)): - plt.plot(ns._data[0], ns._data[i], label=_get_header(ns._data_header[i])) - is_DOS &= 'DOS' in ns._data_header[i] - is_T &= 'T:' in ns._data_header[i] - is_Teig &= 'Teig' in ns._data_header[i] - is_SHOT &= 'Shot' in ns._data_header[i] - is_FANO &= 'Fano' in ns._data_header[i] + plt.plot( + ns._data[0], ns._data[i], label=_get_header(ns._data_header[i]) + ) + is_DOS &= "DOS" in ns._data_header[i] + is_T &= "T:" in ns._data_header[i] + is_Teig &= "Teig" in ns._data_header[i] + is_SHOT &= "Shot" in ns._data_header[i] + is_FANO &= "Fano" in ns._data_header[i] if is_DOS: - plt.ylabel('DOS [1/eV]') + plt.ylabel("DOS [1/eV]") elif is_T: - plt.ylabel('Transmission') + plt.ylabel("Transmission") elif is_Teig: - plt.ylabel('Transmission eigen') + plt.ylabel("Transmission eigen") elif is_FANO: - plt.ylabel('Fano factor') + plt.ylabel("Fano factor") elif is_SHOT: - plt.ylabel('Shot-noise') + plt.ylabel("Shot-noise") else: - plt.ylabel('mixed units') - plt.xlabel('E - E_F [eV]') + plt.ylabel("mixed units") + plt.xlabel("E - E_F [eV]") plt.legend(loc=8, ncol=3, bbox_to_anchor=(0.5, 1.0)) if value is None: @@ -3135,13 +3484,20 @@ def _get_header(header): ns._data_header = [] ns._data = [] # These are expert options - ns._norm = 'none' - ns._Ovalue = '' + ns._norm = "none" + ns._Ovalue = "" ns._Orng = None ns._Erng = None ns._krng = True - p.add_argument('--plot', '-p', action=Plot, nargs='?', metavar='FILE', - help='Plot the currently collected information (at its current invocation).') + + p.add_argument( + "--plot", + "-p", + action=Plot, + nargs="?", + metavar="FILE", + help="Plot the currently collected information (at its current invocation).", + ) return p, namespace @@ -3151,29 +3507,30 @@ def _get_header(header): # with the exception that the k-points have been averaged out. @set_module("sisl.io.tbtrans") class tbtavncSileTBtrans(tbtncSileTBtrans): - """ TBtrans average file object + """TBtrans average file object This `Sile` implements the writing of the TBtrans output ``*.TBT.AV.nc`` sile which contains the k-averaged quantities related to the NEGF code TBtrans. See `tbtncSileTBtrans` for details as this object is essentially a copy of it. """ - _trans_type = 'TBT' + + _trans_type = "TBT" _k_avg = True _E2eV = Ry2eV @property def nkpt(self): - """ Always return 1, this is to signal other routines """ + """Always return 1, this is to signal other routines""" return 1 @property def wkpt(self): - """ Always return [1.], this is to signal other routines """ + """Always return [1.], this is to signal other routines""" return _a.onesd(1) def write_tbtav(self, *args, **kwargs): - """ Wrapper for writing the k-averaged TBT.AV.nc file. + """Wrapper for writing the k-averaged TBT.AV.nc file. This write *requires* the TBT.nc `Sile` object passed as the first argument, or as the keyword ``from=tbt`` argument. @@ -3184,15 +3541,19 @@ def write_tbtav(self, *args, **kwargs): the TBT.nc file object that has the k-sampled quantities. """ - if 'from' in kwargs: - tbt = kwargs['from'] + if "from" in kwargs: + tbt = kwargs["from"] elif len(args) > 0: tbt = args[0] else: - raise SislError("tbtncSileTBtrans has not been passed to write the averaged file") + raise SislError( + "tbtncSileTBtrans has not been passed to write the averaged file" + ) if not isinstance(tbt, tbtncSileTBtrans): - raise ValueError('first argument of tbtavncSileTBtrans.write *must* be a tbtncSileTBtrans object') + raise ValueError( + "first argument of tbtavncSileTBtrans.write *must* be a tbtncSileTBtrans object" + ) # Notify if the object is not in write mode. sile_raise_write(self) @@ -3201,8 +3562,8 @@ def copy_attr(f, t): t.setncatts({att: f.getncattr(att) for att in f.ncattrs()}) # Retrieve k-weights - nkpt = len(tbt.dimensions['nkpt']) - wkpt = _a.asarrayd(tbt.variables['wkpt'][:]) + nkpt = len(tbt.dimensions["nkpt"]) + wkpt = _a.asarrayd(tbt.variables["wkpt"][:]) # First copy and re-create all entries in the output file for dvg in tbt: @@ -3228,10 +3589,9 @@ def copy_attr(f, t): grp = self.createGroup(dvg.group().path) if tbt.isDimension(dvg): - # In case the dimension is the k-point one # we remove that dimension - if 'nkpt' == dvg.name: + if "nkpt" == dvg.name: continue # Simply re-create the dimension @@ -3245,15 +3605,15 @@ def copy_attr(f, t): # It *must* be a variable now # Quickly skip the k-point variable and the weights - if dvg.name in ('kpt', 'wkpt'): + if dvg.name in ("kpt", "wkpt"): continue # Down-scale the k-point dimension - if 'nkpt' in dvg.dimensions: + if "nkpt" in dvg.dimensions: # Remove that dimension dims = list(dvg.dimensions) # Create slice - idx = dims.index('nkpt') + idx = dims.index("nkpt") dims.pop(idx) dims = tuple(dims) has_kpt = True @@ -3264,8 +3624,7 @@ def copy_attr(f, t): # We can't use dvg.filters() since it doesn't always # work... - v = grp.createVariable(dvg.name, dvg.dtype, - dimensions=dims) + v = grp.createVariable(dvg.name, dvg.dtype, dimensions=dims) # Copy attributes copy_attr(dvg, v) @@ -3293,22 +3652,22 @@ def copy_attr(f, t): v[:] = dvg[:] # Update the source attribute to signal the originating file - self.setncattr('source', 'k-average of: ' + str(tbt._file)) + self.setncattr("source", "k-average of: " + str(tbt._file)) self.sync() # Denote default writing routine _write_default = write_tbtav -for _name in ['shot_noise', 'noise_power', 'fano']: +for _name in ["shot_noise", "noise_power", "fano"]: setattr(tbtavncSileTBtrans, _name, None) -add_sile('TBT.nc', tbtncSileTBtrans) +add_sile("TBT.nc", tbtncSileTBtrans) # Add spin-dependent files -add_sile('TBT_DN.nc', tbtncSileTBtrans) -add_sile('TBT_UP.nc', tbtncSileTBtrans) -add_sile('TBT.AV.nc', tbtavncSileTBtrans) +add_sile("TBT_DN.nc", tbtncSileTBtrans) +add_sile("TBT_UP.nc", tbtncSileTBtrans) +add_sile("TBT.AV.nc", tbtavncSileTBtrans) # Add spin-dependent files -add_sile('TBT_DN.AV.nc', tbtavncSileTBtrans) -add_sile('TBT_UP.AV.nc', tbtavncSileTBtrans) +add_sile("TBT_DN.AV.nc", tbtavncSileTBtrans) +add_sile("TBT_UP.AV.nc", tbtavncSileTBtrans) diff --git a/src/sisl/io/tbtrans/tbtproj.py b/src/sisl/io/tbtrans/tbtproj.py index a75ae55f03..d00dd7831f 100644 --- a/src/sisl/io/tbtrans/tbtproj.py +++ b/src/sisl/io/tbtrans/tbtproj.py @@ -16,22 +16,23 @@ from ..sile import add_sile from .tbt import tbtncSileTBtrans -__all__ = ['tbtprojncSileTBtrans'] +__all__ = ["tbtprojncSileTBtrans"] -Bohr2Ang = unit_convert('Bohr', 'Ang') -Ry2eV = unit_convert('Ry', 'eV') -Ry2K = unit_convert('Ry', 'K') -eV2Ry = unit_convert('eV', 'Ry') +Bohr2Ang = unit_convert("Bohr", "Ang") +Ry2eV = unit_convert("Ry", "eV") +Ry2K = unit_convert("Ry", "K") +eV2Ry = unit_convert("eV", "Ry") @set_module("sisl.io.tbtrans") class tbtprojncSileTBtrans(tbtncSileTBtrans): - """ TBtrans projection file object """ - _trans_type = 'TBT.Proj' + """TBtrans projection file object""" + + _trans_type = "TBT.Proj" @classmethod def _mol_proj_elec(self, elec_mol_proj): - """ Parse the electrode-molecule-projection str/tuple into the molecule-projected-electrode + """Parse the electrode-molecule-projection str/tuple into the molecule-projected-electrode Parameters ---------- @@ -39,16 +40,18 @@ def _mol_proj_elec(self, elec_mol_proj): electrode-molecule-projection """ if isinstance(elec_mol_proj, str): - elec_mol_proj = elec_mol_proj.split('.') + elec_mol_proj = elec_mol_proj.split(".") if len(elec_mol_proj) == 1: return elec_mol_proj elif len(elec_mol_proj) != 3: - raise ValueError(f"Projection specification does not contain 3 fields: .. is required.") + raise ValueError( + f"Projection specification does not contain 3 fields: .. is required." + ) return [elec_mol_proj[i] for i in [1, 2, 0]] @property def elecs(self): - """ List of electrodes """ + """List of electrodes""" elecs = [] # in cases of not calculating all @@ -57,13 +60,13 @@ def elecs(self): for group in self.groups.keys(): if group in elecs: continue - if 'mu' in self.groups[group].variables.keys(): + if "mu" in self.groups[group].variables.keys(): elecs.append(group) return elecs @property def molecules(self): - """ List of regions where state projections may happen """ + """List of regions where state projections may happen""" mols = [] for mol in self.groups.keys(): if len(self.groups[mol].groups) > 0: @@ -72,7 +75,7 @@ def molecules(self): return mols def projections(self, molecule): - """ List of projections on `molecule` + """List of projections on `molecule` Parameters ---------- @@ -82,8 +85,17 @@ def projections(self, molecule): mol = self.groups[molecule] return list(mol.groups.keys()) - def ADOS(self, elec_mol_proj, E=None, kavg=True, atoms=None, orbitals=None, sum=True, norm='none'): - r""" Projected spectral density of states (DOS) (1/eV) + def ADOS( + self, + elec_mol_proj, + E=None, + kavg=True, + atoms=None, + orbitals=None, + sum=True, + norm="none", + ): + r"""Projected spectral density of states (DOS) (1/eV) Extract the projected spectral DOS from electrode `elec` on a selected subset of atoms/orbitals in the device region @@ -115,10 +127,19 @@ def ADOS(self, elec_mol_proj, E=None, kavg=True, atoms=None, orbitals=None, sum= how the normalization of the summed DOS is performed (see `norm` routine). """ mol_proj_elec = self._mol_proj_elec(elec_mol_proj) - return self._DOS(self._value_E('ADOS', mol_proj_elec, kavg=kavg, E=E), atoms, orbitals, sum, norm) * eV2Ry + return ( + self._DOS( + self._value_E("ADOS", mol_proj_elec, kavg=kavg, E=E), + atoms, + orbitals, + sum, + norm, + ) + * eV2Ry + ) def transmission(self, elec_mol_proj_from, elec_mol_proj_to, kavg=True): - """ Transmission from `mol_proj_elec_from` to `mol_proj_elec_to` + """Transmission from `mol_proj_elec_from` to `mol_proj_elec_to` Parameters ---------- @@ -136,11 +157,11 @@ def transmission(self, elec_mol_proj_from, elec_mol_proj_to, kavg=True): """ mol_proj_elec = self._mol_proj_elec(elec_mol_proj_from) if not isinstance(elec_mol_proj_to, str): - elec_mol_proj_to = '.'.join(elec_mol_proj_to) - return self._value_avg(elec_mol_proj_to + '.T', mol_proj_elec, kavg=kavg) + elec_mol_proj_to = ".".join(elec_mol_proj_to) + return self._value_avg(elec_mol_proj_to + ".T", mol_proj_elec, kavg=kavg) def transmission_eig(self, elec_mol_proj_from, elec_mol_proj_to, kavg=True): - """ Transmission eigenvalues from `elec_mol_proj_from` to `elec_mol_proj_to` + """Transmission eigenvalues from `elec_mol_proj_from` to `elec_mol_proj_to` Parameters ---------- @@ -158,11 +179,13 @@ def transmission_eig(self, elec_mol_proj_from, elec_mol_proj_to, kavg=True): """ mol_proj_elec = self._mol_proj_elec(elec_mol_proj_from) if not isinstance(elec_mol_proj_to, str): - elec_mol_proj_to = '.'.join(elec_mol_proj_to) - return self._value_avg(elec_mol_proj_to + '.T.Eig', mol_proj_elec, kavg=kavg) + elec_mol_proj_to = ".".join(elec_mol_proj_to) + return self._value_avg(elec_mol_proj_to + ".T.Eig", mol_proj_elec, kavg=kavg) - def Adensity_matrix(self, elec_mol_proj, E, kavg=True, isc=None, orbitals=None, geometry=None): - r""" Projected spectral function density matrix at energy `E` (1/eV) + def Adensity_matrix( + self, elec_mol_proj, E, kavg=True, isc=None, orbitals=None, geometry=None + ): + r"""Projected spectral function density matrix at energy `E` (1/eV) The projected density matrix can be used to calculate the LDOS in real-space. @@ -206,14 +229,17 @@ def Adensity_matrix(self, elec_mol_proj, E, kavg=True, isc=None, orbitals=None, DensityMatrix: the object containing the Geometry and the density matrix elements """ mol_proj_elec = self._mol_proj_elec(elec_mol_proj) - dm = self._sparse_data('DM', mol_proj_elec, E, kavg, isc, orbitals) * eV2Ry + dm = self._sparse_data("DM", mol_proj_elec, E, kavg, isc, orbitals) * eV2Ry # Now create the density matrix object geom = self.read_geometry() if geometry is None: DM = DensityMatrix.fromsp(geom, dm) else: if geom.no != geometry.no: - raise ValueError(self.__class__.__name__ + '.Adensity_matrix requires input geometry to contain the correct number of orbitals. Please correct input!') + raise ValueError( + self.__class__.__name__ + + ".Adensity_matrix requires input geometry to contain the correct number of orbitals. Please correct input!" + ) DM = DensityMatrix.fromsp(geometry, dm) return DM @@ -280,11 +306,11 @@ def orbital_ACOOP(self, elec_mol_proj, E, kavg=True, isc=None, orbitals=None): atom_ACOHP : atomic COHP analysis of the projected spectral function """ mol_proj_elec = self._mol_proj_elec(elec_mol_proj) - COOP = self._sparse_data('COOP', mol_proj_elec, E, kavg, isc, orbitals) * eV2Ry + COOP = self._sparse_data("COOP", mol_proj_elec, E, kavg, isc, orbitals) * eV2Ry return COOP def orbital_ACOHP(self, elec_mol_proj, E, kavg=True, isc=None, orbitals=None): - r""" Orbital COHP analysis of the projected spectral function + r"""Orbital COHP analysis of the projected spectral function This will return a sparse matrix, see ``scipy.sparse.csr_matrix`` for details. Each matrix element of the sparse matrix corresponds to the COHP of the @@ -324,65 +350,91 @@ def orbital_ACOHP(self, elec_mol_proj, E, kavg=True, isc=None, orbitals=None): atom_ACOOP : atomic COOP analysis of the projected spectral function """ mol_proj_elec = self._mol_proj_elec(elec_mol_proj) - COHP = self._sparse_data('COHP', mol_proj_elec, E, kavg, isc, orbitals) + COHP = self._sparse_data("COHP", mol_proj_elec, E, kavg, isc, orbitals) return COHP @default_ArgumentParser(description="Extract data from a TBT.Proj.nc file") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" p, namespace = super().ArgumentParser(p, *args, **kwargs) # We limit the import to occur here import argparse def ensure_E(func): - """ This decorater ensures that E is the first element in the _data container """ + """This decorater ensures that E is the first element in the _data container""" def assign_E(self, *args, **kwargs): ns = args[1] if len(ns._data) == 0: # We immediately extract the energies ns._data.append(ns._tbt.E[ns._Erng].flatten()) - ns._data_header.append('Energy[eV]') + ns._data_header.append("Energy[eV]") return func(self, *args, **kwargs) + return assign_E class InfoMols(argparse.Action): def __call__(self, parser, ns, value, option_string=None): - print(' '.join(ns._tbt.molecules)) - p.add_argument('--molecules', '-M', nargs=0, - action=InfoMols, - help="""Show molecules in the projection file""") + print(" ".join(ns._tbt.molecules)) + + p.add_argument( + "--molecules", + "-M", + nargs=0, + action=InfoMols, + help="""Show molecules in the projection file""", + ) class InfoProjs(argparse.Action): def __call__(self, parser, ns, value, option_string=None): - print(' '.join(ns._tbt.projections(value[0]))) - p.add_argument('--projections', '-P', nargs=1, metavar='MOL', - action=InfoProjs, - help="""Show projections on molecule.""") + print(" ".join(ns._tbt.projections(value[0]))) - class DataDOS(argparse.Action): + p.add_argument( + "--projections", + "-P", + nargs=1, + metavar="MOL", + action=InfoProjs, + help="""Show projections on molecule.""", + ) + class DataDOS(argparse.Action): @collect_action @ensure_E def __call__(self, parser, ns, value, option_string=None): - data = ns._tbt.ADOS(value, kavg=ns._krng, orbitals=ns._Orng, norm=ns._norm) - ns._data_header.append(f'ADOS[1/eV]:{value}') + data = ns._tbt.ADOS( + value, kavg=ns._krng, orbitals=ns._Orng, norm=ns._norm + ) + ns._data_header.append(f"ADOS[1/eV]:{value}") NORM = int(ns._tbt.norm(orbitals=ns._Orng, norm=ns._norm)) # The flatten is because when ns._Erng is None, then a new # dimension (of size 1) is created ns._data.append(data[ns._Erng].flatten()) if ns._Orng is None: - ns._data_description.append('Column {} is sum of all device atoms+orbitals with normalization 1/{}'.format(len(ns._data), NORM)) + ns._data_description.append( + "Column {} is sum of all device atoms+orbitals with normalization 1/{}".format( + len(ns._data), NORM + ) + ) else: - ns._data_description.append('Column {} is atoms[orbs] {} with normalization 1/{}'.format(len(ns._data), ns._Ovalue, NORM)) - p.add_argument('--ados', '-AD', metavar='E.M.P', - action=DataDOS, default=None, - help="""Store projected spectral DOS""") + ns._data_description.append( + "Column {} is atoms[orbs] {} with normalization 1/{}".format( + len(ns._data), ns._Ovalue, NORM + ) + ) + + p.add_argument( + "--ados", + "-AD", + metavar="E.M.P", + action=DataDOS, + default=None, + help="""Store projected spectral DOS""", + ) class DataT(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, values, option_string=None): @@ -390,17 +442,28 @@ def __call__(self, parser, ns, values, option_string=None): elec_mol_proj2 = values[1] # Grab the information - data = ns._tbt.transmission(elec_mol_proj1, elec_mol_proj2, kavg=ns._krng)[ns._Erng] + data = ns._tbt.transmission( + elec_mol_proj1, elec_mol_proj2, kavg=ns._krng + )[ns._Erng] data.shape = (-1,) ns._data.append(data) - ns._data_header.append(f'T:{elec_mol_proj1}-{elec_mol_proj2}') - ns._data_description.append('Column {} is transmission from {} to {}'.format(len(ns._data), elec_mol_proj1, elec_mol_proj2)) - p.add_argument('-T', '--transmission', nargs=2, metavar=('E.M.P1', 'E.M.P2'), - action=DataT, - help='Store transmission between two projections.') + ns._data_header.append(f"T:{elec_mol_proj1}-{elec_mol_proj2}") + ns._data_description.append( + "Column {} is transmission from {} to {}".format( + len(ns._data), elec_mol_proj1, elec_mol_proj2 + ) + ) + + p.add_argument( + "-T", + "--transmission", + nargs=2, + metavar=("E.M.P1", "E.M.P2"), + action=DataT, + help="Store transmission between two projections.", + ) class DataTEig(argparse.Action): - @collect_action @ensure_E def __call__(self, parser, ns, values, option_string=None): @@ -408,20 +471,34 @@ def __call__(self, parser, ns, values, option_string=None): elec_mol_proj2 = values[1] # Grab the information - data = ns._tbt.transmission_eig(elec_mol_proj1, elec_mol_proj2, kavg=ns._krng)[ns._Erng] + data = ns._tbt.transmission_eig( + elec_mol_proj1, elec_mol_proj2, kavg=ns._krng + )[ns._Erng] neig = data.shape[-1] for eig in range(neig): ns._data.append(data[ns._Erng, ..., eig].flatten()) - ns._data_header.append('Teig({}):{}-{}'.format(eig+1, elec_mol_proj1, elec_mol_proj2)) - ns._data_description.append('Column {} is transmission eigenvalues from electrode {} to {}'.format(len(ns._data), elec_mol_proj1, elec_mol_proj2)) - p.add_argument('-Teig', '--transmission-eig', nargs=2, metavar=('E.M.P1', 'E.M.P2'), - action=DataTEig, - help='Store transmission eigenvalues between two projections.') + ns._data_header.append( + "Teig({}):{}-{}".format(eig + 1, elec_mol_proj1, elec_mol_proj2) + ) + ns._data_description.append( + "Column {} is transmission eigenvalues from electrode {} to {}".format( + len(ns._data), elec_mol_proj1, elec_mol_proj2 + ) + ) + + p.add_argument( + "-Teig", + "--transmission-eig", + nargs=2, + metavar=("E.M.P1", "E.M.P2"), + action=DataTEig, + help="Store transmission eigenvalues between two projections.", + ) return p, namespace def info(self, molecule=None): - """ Information about the calculated quantities available for extracting in this file + """Information about the calculated quantities available for extracting in this file Parameters ---------- @@ -430,16 +507,17 @@ def info(self, molecule=None): """ # Create a StringIO object to retain the information out = StringIO() + # Create wrapper function def prnt(*args, **kwargs): - option = kwargs.pop('option', None) + option = kwargs.pop("option", None) if option is None: print(*args, file=out) else: - print('{:70s}[{}]'.format(' '.join(args), ', '.join(option)), file=out) + print("{:70s}[{}]".format(" ".join(args), ", ".join(option)), file=out) def true(string, fdf=None, suf=2): - prnt("{}+ {}: true".format(' ' * suf, string), option=fdf) + prnt("{}+ {}: true".format(" " * suf, string), option=fdf) # Retrieve the device atoms prnt("Device information:") @@ -454,14 +532,18 @@ def true(string, fdf=None, suf=2): nA = len(np.unique(kpt[:, 0])) nB = len(np.unique(kpt[:, 1])) nC = len(np.unique(kpt[:, 2])) - prnt((" - number of kpoints: {} <- " - "[ A = {} , B = {} , C = {} ] (time-reversal unknown)").format(self.nk, nA, nB, nC)) + prnt( + ( + " - number of kpoints: {} <- " + "[ A = {} , B = {} , C = {} ] (time-reversal unknown)" + ).format(self.nk, nA, nB, nC) + ) prnt(" - energy range:") E = self.E Em, EM = np.amin(E), np.amax(E) dE = np.diff(E) - dEm, dEM = np.amin(dE) * 1000, np.amax(dE) * 1000 # convert to meV - if (dEM - dEm) < 1e-3: # 0.001 meV + dEm, dEM = np.amin(dE) * 1000, np.amax(dE) * 1000 # convert to meV + if (dEM - dEm) < 1e-3: # 0.001 meV prnt(f" {Em:.5f} -- {EM:.5f} eV [{dEm:.3f} meV]") else: prnt(f" {Em:.5f} -- {EM:.5f} eV [{dEm:.3f} -- {dEM:.3f} meV]") @@ -479,7 +561,7 @@ def _get_all(opt, vars): indices = [] for i, var in enumerate(vars): if var.endswith(opt): - out.append(var[:-len(opt)]) + out.append(var[: -len(opt)]) indices.append(i) indices.sort(reverse=True) for i in indices: @@ -487,7 +569,7 @@ def _get_all(opt, vars): return out def _print_to(ns, var): - elec_mol_proj = var.split('.') + elec_mol_proj = var.split(".") if len(elec_mol_proj) == 1: prnt(" " * ns + "-> {elec}".format(elec=elec_mol_proj[0])) elif len(elec_mol_proj) == 3: @@ -502,81 +584,83 @@ def _print_to_full(s, vars): for var in vars: _print_to(ns, var) - eig_kwargs = {'precision': 4, 'threshold': 1e6, 'suffix': '', 'prefix': ''} + eig_kwargs = {"precision": 4, "threshold": 1e6, "suffix": "", "prefix": ""} # Print out information for each electrode for mol in mols: - opt = {'mol1': mol} + opt = {"mol1": mol} gmol = self.groups[mol] prnt() prnt(f"Molecule: {mol}") prnt(" - molecule atoms (1-based):") - prnt(" " + list2str(gmol.variables['atom'][:])) + prnt(" " + list2str(gmol.variables["atom"][:])) # molecule states and eigenvalues stored - lvls = gmol.variables['lvl'][:] + lvls = gmol.variables["lvl"][:] lvls = np.where(lvls < 0, lvls + 1, lvls) + gmol.HOMO_index - eigs = gmol.variables['eig'][:] * Ry2eV + eigs = gmol.variables["eig"][:] * Ry2eV prnt(f" - state indices (1-based) (total={lvls.size}):") prnt(" " + list2str(lvls)) prnt(" - state eigenvalues (eV):") - prnt(" " + np.array2string(eigs[lvls-1], **eig_kwargs)[1:-1]) + prnt(" " + np.array2string(eigs[lvls - 1], **eig_kwargs)[1:-1]) projs = self.projections(mol) prnt(" - number of projections: {}".format(len(projs))) for proj in projs: - opt['proj1'] = proj + opt["proj1"] = proj gproj = gmol.groups[proj] prnt(" > Projection: {mol1}.{proj1}".format(**opt)) # Also pretty print the eigenvalues associated with these - lvls = gproj.variables['lvl'][:] + lvls = gproj.variables["lvl"][:] lvls = np.where(lvls < 0, lvls + 1, lvls) + gmol.HOMO_index prnt(f" - state indices (1-based) (total={lvls.size}):") prnt(" " + list2str(lvls)) prnt(" - state eigenvalues:") - prnt(" " + np.array2string(eigs[lvls-1], **eig_kwargs)[1:-1]) + prnt(" " + np.array2string(eigs[lvls - 1], **eig_kwargs)[1:-1]) # Figure out the electrode projections elecs = gproj.groups.keys() for elec in elecs: - opt['elec1'] = elec + opt["elec1"] = elec gelec = gproj.groups[elec] - vars = list(gelec.variables.keys()) # ensure a copy + vars = list(gelec.variables.keys()) # ensure a copy prnt(" > Electrode: {elec1}.{mol1}.{proj1}".format(**opt)) # Loop and figure out what is in it. - if 'ADOS' in vars: - vars.pop(vars.index('ADOS')) - true("DOS spectral", ['TBT.Projs.DOS.A'], suf=8) - if 'J' in vars: - vars.pop(vars.index('J')) - true("orbital-current", ['TBT.Projs.Current.Orb'], suf=8) - if 'DM' in vars: - vars.pop(vars.index('DM')) - true("Density matrix spectral", ['TBT.Projs.DM.A'], suf=8) - if 'COOP' in vars: - vars.pop(vars.index('COOP')) - true("COOP spectral", ['TBT.Projs.COOP.A'], suf=8) - if 'COHP' in vars: - vars.pop(vars.index('COHP')) - true("COHP spectral", ['TBT.Projs.COHP.A'], suf=8) + if "ADOS" in vars: + vars.pop(vars.index("ADOS")) + true("DOS spectral", ["TBT.Projs.DOS.A"], suf=8) + if "J" in vars: + vars.pop(vars.index("J")) + true("orbital-current", ["TBT.Projs.Current.Orb"], suf=8) + if "DM" in vars: + vars.pop(vars.index("DM")) + true("Density matrix spectral", ["TBT.Projs.DM.A"], suf=8) + if "COOP" in vars: + vars.pop(vars.index("COOP")) + true("COOP spectral", ["TBT.Projs.COOP.A"], suf=8) + if "COHP" in vars: + vars.pop(vars.index("COHP")) + true("COHP spectral", ["TBT.Projs.COHP.A"], suf=8) # Retrieve all vars with transmissions - vars_T = _get_all('.T', vars) - vars_Teig = _get_all('.T.Eig', vars) - vars_C = _get_all('.C', vars) - vars_Ceig = _get_all('.C.Eig', vars) + vars_T = _get_all(".T", vars) + vars_Teig = _get_all(".T.Eig", vars) + vars_C = _get_all(".C", vars) + vars_Ceig = _get_all(".C.Eig", vars) _print_to_full(" + transmission:", vars_T) _print_to_full(" + transmission (eigen):", vars_Teig) _print_to_full(" + transmission out corr.:", vars_C) - _print_to_full(" + transmission out corr. (eigen):", vars_Ceig) + _print_to_full( + " + transmission out corr. (eigen):", vars_Ceig + ) # Finally there may be only RHS projections in which case the remaining groups are for # *pristine* electrodes for elec in self.elecs: gelec = self.groups[elec] - vars = list(gelec.variables.keys()) # ensure a copy + vars = list(gelec.variables.keys()) # ensure a copy try: bloch = self.bloch(elec) @@ -585,23 +669,35 @@ def _print_to_full(s, vars): try: n_btd = self.n_btd(elec) except Exception: - n_btd = 'unknown' + n_btd = "unknown" prnt() prnt(f"Electrode: {elec}") prnt(f" - number of BTD blocks: {n_btd}") prnt(" - Bloch: [{}, {}, {}]".format(*bloch)) - if 'TBT' in self._trans_type: - prnt(" - chemical potential: {:.4f} eV".format(self.chemical_potential(elec))) - prnt(" - electron temperature: {:.2f} K".format(self.electron_temperature(elec))) + if "TBT" in self._trans_type: + prnt( + " - chemical potential: {:.4f} eV".format( + self.chemical_potential(elec) + ) + ) + prnt( + " - electron temperature: {:.2f} K".format( + self.electron_temperature(elec) + ) + ) else: - prnt(" - phonon temperature: {:.4f} K".format(self.phonon_temperature(elec))) + prnt( + " - phonon temperature: {:.4f} K".format( + self.phonon_temperature(elec) + ) + ) prnt(" - imaginary part (eta): {:.4f} meV".format(self.eta(elec) * 1e3)) # Retrieve all vars with transmissions - vars_T = _get_all('.T', vars) - vars_Teig = _get_all('.T.Eig', vars) - vars_C = _get_all('.C', vars) - vars_Ceig = _get_all('.C.Eig', vars) + vars_T = _get_all(".T", vars) + vars_Teig = _get_all(".T.Eig", vars) + vars_C = _get_all(".C", vars) + vars_Ceig = _get_all(".C.Eig", vars) _print_to_full(" + transmission:", vars_T) _print_to_full(" + transmission (eigen):", vars_Teig) @@ -613,7 +709,7 @@ def _print_to_full(s, vars): return s def eigenstate(self, molecule, k=None, all=True): - r""" Return the eigenstate on the projected `molecule` + r"""Return the eigenstate on the projected `molecule` The eigenstate object will contain the geometry as the parent object. The eigenstate will be in the Lowdin basis: @@ -634,46 +730,55 @@ def eigenstate(self, molecule, k=None, all=True): ------- EigenstateElectron """ - if 'PHT' in self._trans_type: + if "PHT" in self._trans_type: from sisl.physics import EigenmodePhonon as cls else: from sisl.physics import EigenstateElectron as cls mol = self.groups[molecule] - if all and ('states' in mol.variables or 'Restates' in mol.variables): - suf = 'states' + if all and ("states" in mol.variables or "Restates" in mol.variables): + suf = "states" else: all = False - suf = 'state' + suf = "state" is_gamma = suf in mol.variables if is_gamma: state = mol.variables[suf][:] else: - state = mol.variables['Re' + suf][:] + 1j * mol.variables['Im' + suf][:] - eig = mol.variables['eig'][:] + state = mol.variables["Re" + suf][:] + 1j * mol.variables["Im" + suf][:] + eig = mol.variables["eig"][:] if eig.ndim > 1: - raise NotImplementedError(self.__class__.__name__ + ".eigenstate currently does not implement " - "the k-point version.") + raise NotImplementedError( + self.__class__.__name__ + ".eigenstate currently does not implement " + "the k-point version." + ) geom = self.read_geometry() if all: return cls(state, eig, parent=geom) - lvl = mol.variables['lvl'][:] + lvl = mol.variables["lvl"][:] lvl = np.where(lvl > 0, lvl - 1, lvl) + mol.HOMO_index return cls(state, eig[lvl], parent=geom) -for _name in ['current', 'current_parameter', - 'shot_noise', 'noise_power', 'fano', - 'density_matrix', - 'orbital_COOP', 'atom_COOP', - 'orbital_COHP', 'atom_COHP']: +for _name in [ + "current", + "current_parameter", + "shot_noise", + "noise_power", + "fano", + "density_matrix", + "orbital_COOP", + "atom_COOP", + "orbital_COHP", + "atom_COHP", +]: setattr(tbtprojncSileTBtrans, _name, None) -add_sile('TBT.Proj.nc', tbtprojncSileTBtrans) +add_sile("TBT.Proj.nc", tbtprojncSileTBtrans) # Add spin-dependent files -add_sile('TBT_DN.Proj.nc', tbtprojncSileTBtrans) -add_sile('TBT_UP.Proj.nc', tbtprojncSileTBtrans) +add_sile("TBT_DN.Proj.nc", tbtprojncSileTBtrans) +add_sile("TBT_UP.Proj.nc", tbtprojncSileTBtrans) diff --git a/src/sisl/io/tbtrans/tests/test_delta.py b/src/sisl/io/tbtrans/tests/test_delta.py index 1ea339190a..f1c9c06c2f 100644 --- a/src/sisl/io/tbtrans/tests/test_delta.py +++ b/src/sisl/io/tbtrans/tests/test_delta.py @@ -10,151 +10,148 @@ from sisl.io.tbtrans import * pytestmark = [pytest.mark.io, pytest.mark.tbtrans] -_dir = osp.join('sisl', 'io', 'tbtrans') +_dir = osp.join("sisl", "io", "tbtrans") netCDF4 = pytest.importorskip("netCDF4") def test_tbt_delta1(sisl_tmp, sisl_system): - f = sisl_tmp('gr.dH.nc', _dir) + f = sisl_tmp("gr.dH.nc", _dir) H = Hamiltonian(sisl_system.gtb) H.construct([sisl_system.R, sisl_system.t]) # annoyingly this has to be performed like this... - with deltancSileTBtrans(f, 'w') as sile: + with deltancSileTBtrans(f, "w") as sile: H.geometry.write(sile) - with deltancSileTBtrans(f, 'a') as sile: - + with deltancSileTBtrans(f, "a") as sile: # Write to level-1 sile.write_delta(H) # Write to level-2 - sile.write_delta(H, k=[0, 0, .5]) - assert sile._get_lvl(2).variables['kpt'].shape == (1, 3) + sile.write_delta(H, k=[0, 0, 0.5]) + assert sile._get_lvl(2).variables["kpt"].shape == (1, 3) # Write to level-3 sile.write_delta(H, E=0.1) - assert sile._get_lvl(3).variables['E'].shape == (1, ) + assert sile._get_lvl(3).variables["E"].shape == (1,) sile.write_delta(H, E=0.2) - assert sile._get_lvl(3).variables['E'].shape == (2, ) + assert sile._get_lvl(3).variables["E"].shape == (2,) # Write to level-4 - sile.write_delta(H, E=0.1, k=[0, 0, .5]) - assert sile._get_lvl(4).variables['kpt'].shape == (1, 3) - assert sile._get_lvl(4).variables['E'].shape == (1, ) - sile.write_delta(H, E=0.2, k=[0, 0, .5]) - assert sile._get_lvl(4).variables['kpt'].shape == (1, 3) - assert sile._get_lvl(4).variables['E'].shape == (2, ) - sile.write_delta(H, E=0.2, k=[0, 1., .5]) - assert sile._get_lvl(4).variables['kpt'].shape == (2, 3) - assert sile._get_lvl(4).variables['E'].shape == (2, ) - - with deltancSileTBtrans(f, 'r') as sile: - + sile.write_delta(H, E=0.1, k=[0, 0, 0.5]) + assert sile._get_lvl(4).variables["kpt"].shape == (1, 3) + assert sile._get_lvl(4).variables["E"].shape == (1,) + sile.write_delta(H, E=0.2, k=[0, 0, 0.5]) + assert sile._get_lvl(4).variables["kpt"].shape == (1, 3) + assert sile._get_lvl(4).variables["E"].shape == (2,) + sile.write_delta(H, E=0.2, k=[0, 1.0, 0.5]) + assert sile._get_lvl(4).variables["kpt"].shape == (2, 3) + assert sile._get_lvl(4).variables["E"].shape == (2,) + + with deltancSileTBtrans(f, "r") as sile: # Read to level-1 h = sile.read_delta() assert h.spsame(H) # Read level-2 - h = sile.read_delta(k=[0, 0, .5]) + h = sile.read_delta(k=[0, 0, 0.5]) assert h.spsame(H) # Read level-3 h = sile.read_delta(E=0.1) assert h.spsame(H) # Read level-4 - h = sile.read_delta(E=0.1, k=[0, 0, .5]) + h = sile.read_delta(E=0.1, k=[0, 0, 0.5]) assert h.spsame(H) - h = sile.read_delta(E=0.1, k=[0, 0., .5]) + h = sile.read_delta(E=0.1, k=[0, 0.0, 0.5]) assert h.spsame(H) - h = sile.read_delta(E=0.2, k=[0, 1., .5]) + h = sile.read_delta(E=0.2, k=[0, 1.0, 0.5]) assert h.spsame(H) def test_tbt_delta_fail(sisl_tmp, sisl_system): - f = sisl_tmp('gr.dH.nc', _dir) + f = sisl_tmp("gr.dH.nc", _dir) H = Hamiltonian(sisl_system.gtb) H.construct([sisl_system.R, sisl_system.t]) H.finalize() - with deltancSileTBtrans(f, 'w') as sile: - sile.write_delta(H, k=[0.] * 3) + with deltancSileTBtrans(f, "w") as sile: + sile.write_delta(H, k=[0.0] * 3) for i in range(H.no_s): - H[0, i] = 1. + H[0, i] = 1.0 with pytest.raises(ValueError): sile.write_delta(H, k=[0.2] * 3) def test_tbt_delta_write_read(sisl_tmp, sisl_system): - f = sisl_tmp('gr.dH.nc', _dir) + f = sisl_tmp("gr.dH.nc", _dir) H = Hamiltonian(sisl_system.gtb, dtype=np.complex64) H.construct([sisl_system.R, sisl_system.t]) H.finalize() - with deltancSileTBtrans(f, 'w') as sile: + with deltancSileTBtrans(f, "w") as sile: sile.write_delta(H) - with deltancSileTBtrans(f, 'r') as sile: + with deltancSileTBtrans(f, "r") as sile: h = sile.read_delta() assert h.spsame(H) assert h.dkind == H.dkind def test_tbt_delta_fail_list_col(sisl_tmp, sisl_system): - f = sisl_tmp('gr.dH.nc', _dir) + f = sisl_tmp("gr.dH.nc", _dir) H = Hamiltonian(sisl_system.gtb) H.construct([sisl_system.R, sisl_system.t]) - with deltancSileTBtrans(f, 'w') as sile: - sile.write_delta(H, E=-1.) + with deltancSileTBtrans(f, "w") as sile: + sile.write_delta(H, E=-1.0) edges = H.edges(0) i = edges.max() + 1 del H[0, i - 1] - H[0, i] = 1. + H[0, i] = 1.0 with pytest.raises(ValueError): - sile.write_delta(H, E=1.) + sile.write_delta(H, E=1.0) def test_tbt_delta_fail_ncol(sisl_tmp, sisl_system): - f = sisl_tmp('gr.dH.nc', _dir) + f = sisl_tmp("gr.dH.nc", _dir) H = Hamiltonian(sisl_system.gtb) H.construct([sisl_system.R, sisl_system.t]) - with deltancSileTBtrans(f, 'w') as sile: - sile.write_delta(H, E=-1.) + with deltancSileTBtrans(f, "w") as sile: + sile.write_delta(H, E=-1.0) edges = H.edges(0) i = edges.max() + 1 - H[0, i] = 1. + H[0, i] = 1.0 H.finalize() with pytest.raises(ValueError): - sile.write_delta(H, E=1.) + sile.write_delta(H, E=1.0) def test_tbt_delta_merge(sisl_tmp, sisl_system): - f1 = sisl_tmp('gr1.dH.nc', _dir) - f2 = sisl_tmp('gr2.dH.nc', _dir) - fout = sisl_tmp('grmerged.dH.nc', _dir) - + f1 = sisl_tmp("gr1.dH.nc", _dir) + f2 = sisl_tmp("gr2.dH.nc", _dir) + fout = sisl_tmp("grmerged.dH.nc", _dir) H = Hamiltonian(sisl_system.gtb) H.construct([sisl_system.R, sisl_system.t]) H.finalize() - with deltancSileTBtrans(f1, 'w') as sile: - sile.write_delta(H, E=-1.) - sile.write_delta(H, E=-1., k=[0, 1, 1]) + with deltancSileTBtrans(f1, "w") as sile: + sile.write_delta(H, E=-1.0) + sile.write_delta(H, E=-1.0, k=[0, 1, 1]) sile.write_delta(H) sile.write_delta(H, k=[0, 1, 0]) - with deltancSileTBtrans(f2, 'w') as sile: - sile.write_delta(H, E=-1.) - sile.write_delta(H, E=-1., k=[0, 1, 1]) + with deltancSileTBtrans(f2, "w") as sile: + sile.write_delta(H, E=-1.0) + sile.write_delta(H, E=-1.0, k=[0, 1, 1]) sile.write_delta(H) sile.write_delta(H, k=[0, 1, 0]) # Now merge them deltancSileTBtrans.merge(fout, deltancSileTBtrans(f1), f2) - with deltancSileTBtrans(fout, 'r') as sile: + with deltancSileTBtrans(fout, "r") as sile: h = sile.read_delta() / 2 assert h.spsame(H) - h = sile.read_delta(E=-1.) / 2 + h = sile.read_delta(E=-1.0) / 2 assert h.spsame(H) - h = sile.read_delta(E=-1., k=[0, 1, 1]) / 2 + h = sile.read_delta(E=-1.0, k=[0, 1, 1]) / 2 assert h.spsame(H) h = sile.read_delta(k=[0, 1, 0]) / 2 assert h.spsame(H) diff --git a/src/sisl/io/tbtrans/tests/test_tbt.py b/src/sisl/io/tbtrans/tests/test_tbt.py index c804675075..b4a52f7a88 100644 --- a/src/sisl/io/tbtrans/tests/test_tbt.py +++ b/src/sisl/io/tbtrans/tests/test_tbt.py @@ -12,7 +12,7 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.tbtrans] -_dir = osp.join('sisl', 'io', 'tbtrans') +_dir = osp.join("sisl", "io", "tbtrans") netCDF4 = pytest.importorskip("netCDF4") @@ -20,7 +20,7 @@ @pytest.mark.slow @pytest.mark.filterwarnings("ignore", message="*.*.o2p") def test_1_graphene_all_content(sisl_files): - """ This tests manifolds itself as: + """This tests manifolds itself as: sisl.geom.graphene(orthogonal=True).tile(3, 0).tile(5, 1) @@ -49,9 +49,9 @@ def test_1_graphene_all_content(sisl_files): TBT.k [100 1 1] ### FDF ### """ - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) - assert tbt.E.min() > -2. - assert tbt.E.max() < 2. + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) + assert tbt.E.min() > -2.0 + assert tbt.E.max() < 2.0 # We have 400 energy-points ne = len(tbt.E) assert ne == 400 @@ -61,7 +61,7 @@ def test_1_graphene_all_content(sisl_files): nk = len(tbt.kpt) assert nk == 100 assert tbt.nk == nk - assert tbt.wk.sum() == pytest.approx(1.) + assert tbt.wk.sum() == pytest.approx(1.0) for i in range(ne): assert tbt.Eindex(i) == i @@ -69,13 +69,13 @@ def test_1_graphene_all_content(sisl_files): # Check raises with pytest.warns(sisl.SislWarning): - tbt.Eindex(tbt.E.min() - 1.) + tbt.Eindex(tbt.E.min() - 1.0) with pytest.warns(sisl.SislInfo): tbt.Eindex(tbt.E.min() - 2e-3) with pytest.warns(sisl.SislWarning): tbt.kindex([0, 0, 0.5]) # Can't hit it - #with pytest.warns(sisl.SislInfo): + # with pytest.warns(sisl.SislInfo): # tbt.kindex([0.0106, 0, 0]) for i in range(nk): @@ -97,8 +97,10 @@ def test_1_graphene_all_content(sisl_files): # Check device atoms (1-orbital system) assert tbt.na_d == tbt.no_d - assert tbt.na_d == 36 # 3 * 5 * 4 (and device is without electrodes, so 3 * 3 * 4) - assert len(tbt.pivot()) == 3 * 3 * 4 # 3 * 5 * 4 (and device is without electrodes, so 3 * 3 * 4) + assert tbt.na_d == 36 # 3 * 5 * 4 (and device is without electrodes, so 3 * 3 * 4) + assert ( + len(tbt.pivot()) == 3 * 3 * 4 + ) # 3 * 5 * 4 (and device is without electrodes, so 3 * 3 * 4) assert len(tbt.pivot(in_device=True)) == len(tbt.pivot()) assert np.all(tbt.pivot(in_device=True, sort=True) == np.arange(tbt.no_d)) assert np.all(tbt.pivot(sort=True) == np.sort(tbt.pivot())) @@ -109,15 +111,15 @@ def test_1_graphene_all_content(sisl_files): # Check electrodes assert len(tbt.elecs) == 2 elecs = tbt.elecs[:] - assert elecs == ['Left', 'Right'] + assert elecs == ["Left", "Right"] for i, elec in enumerate(elecs): assert tbt._elec(i) == elec # Check the chemical potentials for elec in elecs: assert tbt.n_btd(elec) == len(tbt.btd(elec)) - assert tbt.chemical_potential(elec) == pytest.approx(0.) - assert tbt.electron_temperature(elec) == pytest.approx(300., abs=1) + assert tbt.chemical_potential(elec) == pytest.approx(0.0) + assert tbt.electron_temperature(elec) == pytest.approx(300.0, abs=1) assert tbt.eta(elec) == pytest.approx(1e-4, abs=1e-6) # Check electrode relevant stuff @@ -125,13 +127,19 @@ def test_1_graphene_all_content(sisl_files): right = elecs[1] # Assert we have transmission symmetry - assert np.allclose(tbt.transmission(left, right), - tbt.transmission(right, left)) - assert np.allclose(tbt.transmission_eig(left, right), - tbt.transmission_eig(right, left)) + assert np.allclose(tbt.transmission(left, right), tbt.transmission(right, left)) + assert np.allclose( + tbt.transmission_eig(left, right), tbt.transmission_eig(right, left) + ) # Check that the total transmission is larger than the sum of transmission eigenvalues - assert np.all(tbt.transmission(left, right) + 1e-7 >= tbt.transmission_eig(left, right).sum(-1)) - assert np.all(tbt.transmission(right, left) + 1e-7 >= tbt.transmission_eig(right, left).sum(-1)) + assert np.all( + tbt.transmission(left, right) + 1e-7 + >= tbt.transmission_eig(left, right).sum(-1) + ) + assert np.all( + tbt.transmission(right, left) + 1e-7 + >= tbt.transmission_eig(right, left).sum(-1) + ) # Check that we can't retrieve from same to same electrode with pytest.raises(ValueError): @@ -139,62 +147,86 @@ def test_1_graphene_all_content(sisl_files): with pytest.raises(ValueError): tbt.transmission_eig(left, left) - assert np.allclose(tbt.transmission(left, right, kavg=False), - tbt.transmission(right, left, kavg=False)) + assert np.allclose( + tbt.transmission(left, right, kavg=False), + tbt.transmission(right, left, kavg=False), + ) # Both methods should be identical for simple bulk systems - assert np.allclose(tbt.reflection(left), tbt.reflection(left, from_single=True), atol=1e-5) + assert np.allclose( + tbt.reflection(left), tbt.reflection(left, from_single=True), atol=1e-5 + ) # Also check for each k for ik in range(nk): - assert np.allclose(tbt.transmission(left, right, ik), - tbt.transmission(right, left, ik)) - assert np.allclose(tbt.transmission_eig(left, right, ik), - tbt.transmission_eig(right, left, ik)) - assert np.all(tbt.transmission(left, right, ik) + 1e-7 >= tbt.transmission_eig(left, right, ik).sum(-1)) - assert np.all(tbt.transmission(right, left, ik) + 1e-7 >= tbt.transmission_eig(right, left, ik).sum(-1)) - assert np.allclose(tbt.DOS(kavg=ik), tbt.ADOS(left, kavg=ik) + tbt.ADOS(right, kavg=ik)) - assert np.allclose(tbt.DOS(E=0.195, kavg=ik), tbt.ADOS(left, E=0.195, kavg=ik) + tbt.ADOS(right, E=0.195, kavg=ik)) + assert np.allclose( + tbt.transmission(left, right, ik), tbt.transmission(right, left, ik) + ) + assert np.allclose( + tbt.transmission_eig(left, right, ik), tbt.transmission_eig(right, left, ik) + ) + assert np.all( + tbt.transmission(left, right, ik) + 1e-7 + >= tbt.transmission_eig(left, right, ik).sum(-1) + ) + assert np.all( + tbt.transmission(right, left, ik) + 1e-7 + >= tbt.transmission_eig(right, left, ik).sum(-1) + ) + assert np.allclose( + tbt.DOS(kavg=ik), tbt.ADOS(left, kavg=ik) + tbt.ADOS(right, kavg=ik) + ) + assert np.allclose( + tbt.DOS(E=0.195, kavg=ik), + tbt.ADOS(left, E=0.195, kavg=ik) + tbt.ADOS(right, E=0.195, kavg=ik), + ) # Check that norm returns correct values assert tbt.norm() == 1 - assert tbt.norm(norm='all') == tbt.no_d - assert tbt.norm(norm='atom') == tbt.norm(norm='orbital') + assert tbt.norm(norm="all") == tbt.no_d + assert tbt.norm(norm="atom") == tbt.norm(norm="orbital") # Check atom is equivalent to orbital - for norm in ['atom', 'orbital']: - assert tbt.norm(0, norm=norm) == 0. - assert tbt.norm(3*4, norm=norm) == 1 - assert tbt.norm(range(3*4, 3*5), norm=norm) == 3 + for norm in ["atom", "orbital"]: + assert tbt.norm(0, norm=norm) == 0.0 + assert tbt.norm(3 * 4, norm=norm) == 1 + assert tbt.norm(range(3 * 4, 3 * 5), norm=norm) == 3 # Assert sum(ADOS) == DOS assert np.allclose(tbt.DOS(), tbt.ADOS(left) + tbt.ADOS(right)) - assert np.allclose(tbt.DOS(sum=False), tbt.ADOS(left, sum=False) + tbt.ADOS(right, sum=False)) + assert np.allclose( + tbt.DOS(sum=False), tbt.ADOS(left, sum=False) + tbt.ADOS(right, sum=False) + ) # Now check orbital resolved DOS - assert np.allclose(tbt.DOS(sum=False), tbt.ADOS(left, sum=False) + tbt.ADOS(right, sum=False)) + assert np.allclose( + tbt.DOS(sum=False), tbt.ADOS(left, sum=False) + tbt.ADOS(right, sum=False) + ) # Current must be 0 when the chemical potentials are equal - assert tbt.current(left, right) == pytest.approx(0.) - assert tbt.current(right, left) == pytest.approx(0.) + assert tbt.current(left, right) == pytest.approx(0.0) + assert tbt.current(right, left) == pytest.approx(0.0) high_low = tbt.current_parameter(left, 0.5, 0.0025, right, -0.5, 0.0025) low_high = tbt.current_parameter(left, -0.5, 0.0025, right, 0.5, 0.0025) - assert high_low > 0. - assert low_high < 0. - assert - high_low == pytest.approx(low_high) + assert high_low > 0.0 + assert low_high < 0.0 + assert -high_low == pytest.approx(low_high) with pytest.warns(sisl.SislWarning): - tbt.current_parameter(left, -10., 0.0025, right, 10., 0.0025) + tbt.current_parameter(left, -10.0, 0.0025, right, 10.0, 0.0025) with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") # Since this is a perfect system there should be *no* QM shot-noise # Also, the shot-noise is related to the applied bias, so NO shot-noise - assert np.allclose((tbt.shot_noise(left, right, kavg=False) * tbt.wkpt.reshape(-1, 1)).sum(0), 0.) - assert np.allclose(tbt.shot_noise(left, right), 0.) - assert np.allclose(tbt.shot_noise(right, left), 0.) - assert np.allclose(tbt.shot_noise(left, right, kavg=1), 0.) + assert np.allclose( + (tbt.shot_noise(left, right, kavg=False) * tbt.wkpt.reshape(-1, 1)).sum(0), + 0.0, + ) + assert np.allclose(tbt.shot_noise(left, right), 0.0) + assert np.allclose(tbt.shot_noise(right, left), 0.0) + assert np.allclose(tbt.shot_noise(left, right, kavg=1), 0.0) # Since the data-file does not contain all T-eigs (only the first two) # we can't correctly calculate the fano factors @@ -203,58 +235,75 @@ def test_1_graphene_all_content(sisl_files): # smearing. # When calculating the Fano factor it is extremely important that the zero_T is *sufficient* # I don't now which value is *good* - assert np.all((tbt.fano(left, right, kavg=False) * tbt.wkpt.reshape(-1, 1)).sum(0) <= 1) + assert np.all( + (tbt.fano(left, right, kavg=False) * tbt.wkpt.reshape(-1, 1)).sum(0) <= 1 + ) assert np.all(tbt.fano(left, right) <= 1) assert np.all(tbt.fano(right, left) <= 1) assert np.all(tbt.fano(left, right, kavg=0) <= 1) # Neither should the noise_power exist - assert (tbt.noise_power(right, left, kavg=False) * tbt.wkpt).sum() == pytest.approx(0.) - assert tbt.noise_power(right, left) == pytest.approx(0.) - assert tbt.noise_power(right, left, kavg=0) == pytest.approx(0.) + assert ( + tbt.noise_power(right, left, kavg=False) * tbt.wkpt + ).sum() == pytest.approx(0.0) + assert tbt.noise_power(right, left) == pytest.approx(0.0) + assert tbt.noise_power(right, left, kavg=0) == pytest.approx(0.0) # Check specific DOS queries DOS = tbt.DOS ADOS = tbt.ADOS - assert DOS(2, atoms=True, sum=False).size == geom.names['Device'].size - assert np.allclose(DOS(2, atoms='Device', sum=False), DOS(2, atoms=True, sum=False)) - assert DOS(2, orbitals=True, sum=False).size == geom.a2o('Device', all=True).size - assert ADOS(left, 2, atoms=True, sum=False).size == geom.names['Device'].size - assert ADOS(left, 2, orbitals=True, sum=False).size == geom.a2o('Device', all=True).size - assert np.allclose(ADOS(left, 2, atoms='Device', sum=False), ADOS(left, 2, atoms=True, sum=False)) - - atoms = range(8, 40) # some in device, some not in device - for o in ['atoms', 'orbitals']: + assert DOS(2, atoms=True, sum=False).size == geom.names["Device"].size + assert np.allclose(DOS(2, atoms="Device", sum=False), DOS(2, atoms=True, sum=False)) + assert DOS(2, orbitals=True, sum=False).size == geom.a2o("Device", all=True).size + assert ADOS(left, 2, atoms=True, sum=False).size == geom.names["Device"].size + assert ( + ADOS(left, 2, orbitals=True, sum=False).size + == geom.a2o("Device", all=True).size + ) + assert np.allclose( + ADOS(left, 2, atoms="Device", sum=False), ADOS(left, 2, atoms=True, sum=False) + ) + + atoms = range(8, 40) # some in device, some not in device + for o in ["atoms", "orbitals"]: opt = {o: atoms} for E in [None, 2, 4]: assert np.allclose(DOS(E), ADOS(left, E) + ADOS(right, E)) - assert np.allclose(DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt)) + assert np.allclose( + DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt) + ) - opt['sum'] = False + opt["sum"] = False for E in [None, 2, 4]: assert np.allclose(DOS(E), ADOS(left, E) + ADOS(right, E)) - assert np.allclose(DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt)) + assert np.allclose( + DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt) + ) - opt['sum'] = True - opt['norm'] = o[:-1] + opt["sum"] = True + opt["norm"] = o[:-1] for E in [None, 2, 4]: assert np.allclose(DOS(E), ADOS(left, E) + ADOS(right, E)) - assert np.allclose(DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt)) + assert np.allclose( + DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt) + ) - opt['sum'] = False + opt["sum"] = False for E in [None, 2, 4]: assert np.allclose(DOS(E), ADOS(left, E) + ADOS(right, E)) - assert np.allclose(DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt)) + assert np.allclose( + DOS(E, **opt), ADOS(left, E, **opt) + ADOS(right, E, **opt) + ) # Check orbital currents E = 201 # Sum of orbital current should be 0 (in == out) orb_left = tbt.orbital_transmission(E, left) orb_right = tbt.orbital_transmission(E, right) - assert orb_left.sum() == pytest.approx(0., abs=1e-7) - assert orb_right.sum() == pytest.approx(0., abs=1e-7) + assert orb_left.sum() == pytest.approx(0.0, abs=1e-7) + assert orb_right.sum() == pytest.approx(0.0, abs=1e-7) d1 = np.arange(12, 24).reshape(-1, 1) d2 = np.arange(24, 36).reshape(-1, 1) @@ -264,21 +313,26 @@ def test_1_graphene_all_content(sisl_files): assert orb_right[d2, d1.T].sum() == pytest.approx(-orb_right[d1, d2.T].sum()) orb_left.sort_indices() - atom_left = tbt.bond_transmission(E, left, what='all') + atom_left = tbt.bond_transmission(E, left, what="all") atom_left.sort_indices() assert np.allclose(orb_left.data, atom_left.data) assert np.allclose(orb_left.data, tbt.sparse_orbital_to_atom(orb_left).data) orb_right.sort_indices() - atom_right = tbt.bond_transmission(E, right, what='all') + atom_right = tbt.bond_transmission(E, right, what="all") atom_right.sort_indices() assert np.allclose(orb_right.data, atom_right.data) assert np.allclose(orb_right.data, tbt.sparse_orbital_to_atom(orb_right).data) # Calculate the atom current # For 1-orbital systems the activity and non-activity are equivalent - assert np.allclose(tbt.atom_transmission(E, left), tbt.atom_transmission(E, left, activity=False)) + assert np.allclose( + tbt.atom_transmission(E, left), tbt.atom_transmission(E, left, activity=False) + ) tbt.vector_transmission(E, left) - assert np.allclose(tbt.sparse_atom_to_vector(atom_left) / 2, tbt.vector_transmission(E, left, what='all')) + assert np.allclose( + tbt.sparse_atom_to_vector(atom_left) / 2, + tbt.vector_transmission(E, left, what="all"), + ) # Check COOP curves coop = tbt.orbital_COOP(E) @@ -357,19 +411,19 @@ def test_1_graphene_all_content(sisl_files): @pytest.mark.slow def test_1_graphene_all_tbtav(sisl_files, sisl_tmp): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) - f = sisl_tmp('1_graphene_all.TBT.AV.nc', _dir) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) + f = sisl_tmp("1_graphene_all.TBT.AV.nc", _dir) tbt.write_tbtav(f) def test_1_graphene_all_fail_kavg(sisl_files, sisl_tmp): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) with pytest.raises(ValueError): tbt.transmission(kavg=[0, 1]) def test_1_graphene_sparse_current(sisl_files, sisl_tmp): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) J = tbt.orbital_current() assert np.allclose(J.data, 0) @@ -391,7 +445,7 @@ def test_1_graphene_sparse_current(sisl_files, sisl_tmp): @pytest.mark.filterwarnings("ignore:.*requesting energy") def test_1_graphene_all_fail_kavg_E(sisl_files, sisl_tmp): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) with pytest.raises(ValueError): tbt.orbital_COOP(kavg=[0, 1], E=0.1) @@ -399,6 +453,7 @@ def test_1_graphene_all_fail_kavg_E(sisl_files, sisl_tmp): def test_1_graphene_all_ArgumentParser(sisl_files, sisl_tmp): pytest.importorskip("matplotlib", reason="matplotlib not available") import matplotlib as mpl + mpl.rcParams["text.usetex"] = False # Local routine to run the collected actions @@ -410,137 +465,158 @@ def run(ns): ns._actions_run = False ns._actions = [] - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) p, ns = tbt.ArgumentParser() p.parse_args([], namespace=ns) p, ns = tbt.ArgumentParser() - out = p.parse_args(['--energy', ' -1.995:1.995'], namespace=ns) + out = p.parse_args(["--energy", " -1.995:1.995"], namespace=ns) assert not out._actions_run run(out) p, ns = tbt.ArgumentParser() - out = p.parse_args(['--kpoint', '1'], namespace=ns) + out = p.parse_args(["--kpoint", "1"], namespace=ns) assert out._krng run(out) assert out._krng == 1 p, ns = tbt.ArgumentParser() - out = p.parse_args(['--norm', 'orbital'], namespace=ns) + out = p.parse_args(["--norm", "orbital"], namespace=ns) run(out) - assert out._norm == 'orbital' + assert out._norm == "orbital" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--norm', 'atom'], namespace=ns) + out = p.parse_args(["--norm", "atom"], namespace=ns) run(out) - assert out._norm == 'atom' + assert out._norm == "atom" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--kpoint', '1', '--norm', 'orbital'], namespace=ns) + out = p.parse_args(["--kpoint", "1", "--norm", "orbital"], namespace=ns) run(out) assert out._krng == 1 - assert out._norm == 'orbital' + assert out._norm == "orbital" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--atom', '10:11,14'], namespace=ns) + out = p.parse_args(["--atom", "10:11,14"], namespace=ns) run(out) - assert out._Ovalue == '10:11,14' + assert out._Ovalue == "10:11,14" # Only atom 14 is in the device region assert np.all(out._Orng + 1 == [14]) p, ns = tbt.ArgumentParser() - out = p.parse_args(['--atom', '10:11,12,14:20'], namespace=ns) + out = p.parse_args(["--atom", "10:11,12,14:20"], namespace=ns) run(out) - assert out._Ovalue == '10:11,12,14:20' + assert out._Ovalue == "10:11,12,14:20" # Only 13-48 is in the device assert np.all(out._Orng + 1 == [14, 15, 16, 17, 18, 19, 20]) p, ns = tbt.ArgumentParser() - out = p.parse_args(['--transmission', 'Left', 'Right'], namespace=ns) + out = p.parse_args(["--transmission", "Left", "Right"], namespace=ns) run(out) assert len(out._data) == 2 - assert out._data_header[0][0] == 'E' - assert out._data_header[1][0] == 'T' + assert out._data_header[0][0] == "E" + assert out._data_header[1][0] == "T" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--transmission', 'Left', 'Right', - '--transmission-bulk', 'Left'], namespace=ns) + out = p.parse_args( + ["--transmission", "Left", "Right", "--transmission-bulk", "Left"], namespace=ns + ) run(out) assert len(out._data) == 3 - assert out._data_header[0][0] == 'E' - assert out._data_header[1][0] == 'T' - assert out._data_header[2][:2] == 'BT' + assert out._data_header[0][0] == "E" + assert out._data_header[1][0] == "T" + assert out._data_header[2][:2] == "BT" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--dos', '--dos', 'Left', '--ados', 'Right'], namespace=ns) + out = p.parse_args(["--dos", "--dos", "Left", "--ados", "Right"], namespace=ns) run(out) assert len(out._data) == 4 - assert out._data_header[0][0] == 'E' - assert out._data_header[1][0] == 'D' - assert out._data_header[2][:2] == 'AD' - assert out._data_header[3][:2] == 'AD' + assert out._data_header[0][0] == "E" + assert out._data_header[1][0] == "D" + assert out._data_header[2][:2] == "AD" + assert out._data_header[3][:2] == "AD" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--bulk-dos', 'Left', '--ados', 'Right'], namespace=ns) + out = p.parse_args(["--bulk-dos", "Left", "--ados", "Right"], namespace=ns) run(out) assert len(out._data) == 3 - assert out._data_header[0][0] == 'E' - assert out._data_header[1][:2] == 'BD' - assert out._data_header[2][:2] == 'AD' + assert out._data_header[0][0] == "E" + assert out._data_header[1][:2] == "BD" + assert out._data_header[2][:2] == "AD" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--transmission-eig', 'Left', 'Right'], namespace=ns) + out = p.parse_args(["--transmission-eig", "Left", "Right"], namespace=ns) run(out) - assert out._data_header[0][0] == 'E' + assert out._data_header[0][0] == "E" for i in range(1, len(out._data)): - assert out._data_header[i][:4] == 'Teig' + assert out._data_header[i][:4] == "Teig" p, ns = tbt.ArgumentParser() - out = p.parse_args(['--info'], namespace=ns) + out = p.parse_args(["--info"], namespace=ns) # Test output - f = sisl_tmp('1_graphene_all.dat', _dir) + f = sisl_tmp("1_graphene_all.dat", _dir) p, ns = tbt.ArgumentParser() - out = p.parse_args(['--transmission-eig', 'Left', 'Right', '--out', f], namespace=ns) + out = p.parse_args( + ["--transmission-eig", "Left", "Right", "--out", f], namespace=ns + ) assert len(out._data) == 0 - f1 = sisl_tmp('1_graphene_all_1.dat', _dir) - f2 = sisl_tmp('1_graphene_all_2.dat', _dir) + f1 = sisl_tmp("1_graphene_all_1.dat", _dir) + f2 = sisl_tmp("1_graphene_all_2.dat", _dir) p, ns = tbt.ArgumentParser() - out = p.parse_args(['--transmission', 'Left', 'Right', '--out', f1, - '--dos', '--atom', '12:2:48', '--dos', 'Right', '--ados', 'Left', '--out', f2], namespace=ns) + out = p.parse_args( + [ + "--transmission", + "Left", + "Right", + "--out", + f1, + "--dos", + "--atom", + "12:2:48", + "--dos", + "Right", + "--ados", + "Left", + "--out", + f2, + ], + namespace=ns, + ) d = sisl.io.tableSile(f1).read_data() assert len(d) == 2 d = sisl.io.tableSile(f2).read_data() assert len(d) == 4 - assert np.allclose(d[1, :], (d[2, :] + d[3, :])* 2) + assert np.allclose(d[1, :], (d[2, :] + d[3, :]) * 2) assert np.allclose(d[2, :], d[3, :]) - f = sisl_tmp('1_graphene_all_T.png', _dir) + f = sisl_tmp("1_graphene_all_T.png", _dir) p, ns = tbt.ArgumentParser() - out = p.parse_args(['--transmission', 'Left', 'Right', - '--transmission-bulk', 'Left', - '--plot', f], namespace=ns) + out = p.parse_args( + ["--transmission", "Left", "Right", "--transmission-bulk", "Left", "--plot", f], + namespace=ns, + ) # Requesting an orbital outside of the device region def test_1_graphene_all_warn_orbital(sisl_files): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) with pytest.warns(sisl.SislWarning): tbt.o2p(1) # Requesting an atom outside of the device region def test_1_graphene_all_warn_atom(sisl_files): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) with pytest.warns(sisl.SislWarning): tbt.a2p(1) def test_1_graphene_all_sparse_data_isc_request(sisl_files): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) # get supercell with isc sc = tbt.read_lattice() @@ -555,14 +631,13 @@ def test_1_graphene_all_sparse_data_isc_request(sisl_files): # partial summed isc # Test that the full matrix and individual access is the same - J_sum = sum(tbt.orbital_transmission(204, elec, isc=isc) - for isc in sc.sc_off) + J_sum = sum(tbt.orbital_transmission(204, elec, isc=isc) for isc in sc.sc_off) assert J_sum.nnz == J_all.nnz assert (J_sum - J_all).nnz == 0 def test_1_graphene_all_sparse_data_orbitals(sisl_files): - tbt = sisl.get_sile(sisl_files(_dir, '1_graphene_all.TBT.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "1_graphene_all.TBT.nc")) # request the full matrix J_all = tbt.orbital_transmission(204, 0) diff --git a/src/sisl/io/tbtrans/tests/test_tbtproj.py b/src/sisl/io/tbtrans/tests/test_tbtproj.py index 2b123bc499..aa93598881 100644 --- a/src/sisl/io/tbtrans/tests/test_tbtproj.py +++ b/src/sisl/io/tbtrans/tests/test_tbtproj.py @@ -11,15 +11,15 @@ import sisl pytestmark = [pytest.mark.io, pytest.mark.tbtrans] -_dir = osp.join('sisl', 'io', 'tbtrans') +_dir = osp.join("sisl", "io", "tbtrans") netCDF4 = pytest.importorskip("netCDF4") @pytest.mark.slow def test_2_projection_content(sisl_files): - tbt = sisl.get_sile(sisl_files(_dir, '2_projection.TBT.nc')) - tbtp = sisl.get_sile(sisl_files(_dir, '2_projection.TBT.Proj.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "2_projection.TBT.nc")) + tbtp = sisl.get_sile(sisl_files(_dir, "2_projection.TBT.Proj.nc")) assert np.allclose(tbt.E, tbtp.E) assert np.allclose(tbt.kpt, tbtp.kpt) @@ -36,27 +36,31 @@ def test_2_projection_content(sisl_files): for mol in tbtp.molecules: for proj in tbtp.projections(mol): t1 = tbtp.transmission((left, mol, proj), (right, mol, proj)) - t2 = tbtp.transmission('.'.join((left, mol, proj)), '.'.join((right, mol, proj))) + t2 = tbtp.transmission( + ".".join((left, mol, proj)), ".".join((right, mol, proj)) + ) assert np.allclose(t1, t2) te1 = tbtp.transmission_eig((left, mol, proj), (right, mol, proj)) - te2 = tbtp.transmission_eig('.'.join((left, mol, proj)), '.'.join((right, mol, proj))) + te2 = tbtp.transmission_eig( + ".".join((left, mol, proj)), ".".join((right, mol, proj)) + ) assert np.allclose(te1, te2) assert np.allclose(t1, te1.sum(-1)) assert np.allclose(t2, te2.sum(-1)) # Check eigenstate - es = tbtp.eigenstate('C60') - assert len(es) == 3 # 1-HOMO, 2-LUMO - assert (es.eig < 0.).nonzero()[0].size == 1 - assert (es.eig > 0.).nonzero()[0].size == 2 + es = tbtp.eigenstate("C60") + assert len(es) == 3 # 1-HOMO, 2-LUMO + assert (es.eig < 0.0).nonzero()[0].size == 1 + assert (es.eig > 0.0).nonzero()[0].size == 2 assert np.allclose(es.norm2(), 1) @pytest.mark.slow def test_2_projection_tbtav(sisl_files, sisl_tmp): - tbt = sisl.get_sile(sisl_files(_dir, '2_projection.TBT.Proj.nc')) - f = sisl_tmp('2_projection.TBT.Proj.AV.nc', _dir) + tbt = sisl.get_sile(sisl_files(_dir, "2_projection.TBT.Proj.nc")) + f = sisl_tmp("2_projection.TBT.Proj.AV.nc", _dir) tbt.write_tbtav(f) @@ -73,91 +77,126 @@ def run(ns): ns._actions_run = False ns._actions = [] - tbt = sisl.get_sile(sisl_files(_dir, '2_projection.TBT.Proj.nc')) + tbt = sisl.get_sile(sisl_files(_dir, "2_projection.TBT.Proj.nc")) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) p.parse_args([], namespace=ns) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--energy', ' -1.995:1.995'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args(["--energy", " -1.995:1.995"], namespace=ns) assert not out._actions_run run(out) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--norm', 'orbital'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args(["--norm", "orbital"], namespace=ns) run(out) - assert out._norm == 'orbital' + assert out._norm == "orbital" - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--norm', 'atom'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args(["--norm", "atom"], namespace=ns) run(out) - assert out._norm == 'atom' + assert out._norm == "atom" - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--atom', '10:11,14'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args(["--atom", "10:11,14"], namespace=ns) run(out) - assert out._Ovalue == '10:11,14' + assert out._Ovalue == "10:11,14" # Only atom 14 is in the device region assert np.all(out._Orng + 1 == [14]) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--atom', '10:11,12,14:20'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args(["--atom", "10:11,12,14:20"], namespace=ns) run(out) - assert out._Ovalue == '10:11,12,14:20' + assert out._Ovalue == "10:11,12,14:20" # Only 13-72 is in the device assert np.all(out._Orng + 1 == [14, 15, 16, 17, 18, 19, 20]) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--transmission', 'Left.C60.HOMO', 'Right.C60.HOMO'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args( + ["--transmission", "Left.C60.HOMO", "Right.C60.HOMO"], namespace=ns + ) run(out) assert len(out._data) == 2 - assert out._data_header[0][0] == 'E' - assert out._data_header[1][0] == 'T' + assert out._data_header[0][0] == "E" + assert out._data_header[1][0] == "T" - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--molecules', '-P', 'C60'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args(["--molecules", "-P", "C60"], namespace=ns) run(out) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--transmission', 'Left', 'Right.C60.LUMO', - '--transmission', 'Left.C60.LUMO', 'Right'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args( + [ + "--transmission", + "Left", + "Right.C60.LUMO", + "--transmission", + "Left.C60.LUMO", + "Right", + ], + namespace=ns, + ) run(out) assert len(out._data) == 3 - assert out._data_header[0][0] == 'E' - assert out._data_header[1][0] == 'T' - assert out._data_header[2][0] == 'T' - - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--ados', 'Left.C60.HOMO', '--ados', 'Left.C60.LUMO'], namespace=ns) + assert out._data_header[0][0] == "E" + assert out._data_header[1][0] == "T" + assert out._data_header[2][0] == "T" + + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args( + ["--ados", "Left.C60.HOMO", "--ados", "Left.C60.LUMO"], namespace=ns + ) run(out) assert len(out._data) == 3 - assert out._data_header[0][0] == 'E' - assert out._data_header[1][:2] == 'AD' - assert out._data_header[2][:2] == 'AD' - - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--transmission-eig', 'Left.C60.HOMO', 'Right.C60.LUMO'], namespace=ns) + assert out._data_header[0][0] == "E" + assert out._data_header[1][:2] == "AD" + assert out._data_header[2][:2] == "AD" + + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args( + ["--transmission-eig", "Left.C60.HOMO", "Right.C60.LUMO"], namespace=ns + ) run(out) - assert out._data_header[0][0] == 'E' + assert out._data_header[0][0] == "E" for i in range(1, len(out._data)): - assert out._data_header[i][:4] == 'Teig' + assert out._data_header[i][:4] == "Teig" - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--info'], namespace=ns) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args(["--info"], namespace=ns) # Test output - f = sisl_tmp('2_projection.dat', _dir) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--transmission-eig', 'Left', 'Right.C60.HOMO', '--out', f], namespace=ns) + f = sisl_tmp("2_projection.dat", _dir) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args( + ["--transmission-eig", "Left", "Right.C60.HOMO", "--out", f], namespace=ns + ) assert len(out._data) == 0 - f1 = sisl_tmp('2_projection_1.dat', _dir) - f2 = sisl_tmp('2_projection_2.dat', _dir) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--transmission', 'Left', 'Right.C60.HOMO', '--out', f1, - '--ados', 'Left.C60.HOMO', - '--atom', '13:2:72', '--ados', 'Left.C60.HOMO', - '--atom', '14:2:72', '--ados', 'Left.C60.HOMO', '--out', f2], namespace=ns) + f1 = sisl_tmp("2_projection_1.dat", _dir) + f2 = sisl_tmp("2_projection_2.dat", _dir) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args( + [ + "--transmission", + "Left", + "Right.C60.HOMO", + "--out", + f1, + "--ados", + "Left.C60.HOMO", + "--atom", + "13:2:72", + "--ados", + "Left.C60.HOMO", + "--atom", + "14:2:72", + "--ados", + "Left.C60.HOMO", + "--out", + f2, + ], + namespace=ns, + ) d = sisl.io.tableSile(f1).read_data() assert len(d) == 2 @@ -165,8 +204,18 @@ def run(ns): assert len(d) == 4 assert np.allclose(d[1, :], d[2, :] + d[3, :]) - f = sisl_tmp('2_projection_T.png', _dir) - p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler='resolve')) - out = p.parse_args(['--transmission', 'Left', 'Right.C60.HOMO', - '--transmission', 'Left.C60.HOMO', 'Right.C60.HOMO', - '--plot', f], namespace=ns) + f = sisl_tmp("2_projection_T.png", _dir) + p, ns = tbt.ArgumentParser(argparse.ArgumentParser(conflict_handler="resolve")) + out = p.parse_args( + [ + "--transmission", + "Left", + "Right.C60.HOMO", + "--transmission", + "Left.C60.HOMO", + "Right.C60.HOMO", + "--plot", + f, + ], + namespace=ns, + ) diff --git a/src/sisl/io/tests/test_cube.py b/src/sisl/io/tests/test_cube.py index d526cce657..b90e3e3b05 100644 --- a/src/sisl/io/tests/test_cube.py +++ b/src/sisl/io/tests/test_cube.py @@ -10,11 +10,11 @@ from sisl.io.cube import * pytestmark = [pytest.mark.io, pytest.mark.generic] -_dir = osp.join('sisl', 'io') +_dir = osp.join("sisl", "io") def test_default(sisl_tmp): - f = sisl_tmp('GRID.cube', _dir) + f = sisl_tmp("GRID.cube", _dir) grid = Grid(0.2) grid.grid = np.random.rand(*grid.shape) grid.write(f) @@ -25,7 +25,7 @@ def test_default(sisl_tmp): def test_default_size(sisl_tmp): - f = sisl_tmp('GRID.cube', _dir) + f = sisl_tmp("GRID.cube", _dir) grid = Grid(0.2, lattice=2.0) grid.grid = np.random.rand(*grid.shape) grid.write(f) @@ -36,8 +36,12 @@ def test_default_size(sisl_tmp): def test_geometry(sisl_tmp): - f = sisl_tmp('GRID.cube', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) + f = sisl_tmp("GRID.cube", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) grid = Grid(0.2, geometry=geom) grid.grid = np.random.rand(*grid.shape) grid.write(f) @@ -57,16 +61,20 @@ def test_geometry(sisl_tmp): def test_imaginary(sisl_tmp): - fr = sisl_tmp('GRID_real.cube', _dir) - fi = sisl_tmp('GRID_imag.cube', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) + fr = sisl_tmp("GRID_real.cube", _dir) + fi = sisl_tmp("GRID_imag.cube", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) grid = Grid(0.2, geometry=geom, dtype=np.complex128) - grid.grid = np.random.rand(*grid.shape) + 1j*np.random.rand(*grid.shape) + grid.grid = np.random.rand(*grid.shape) + 1j * np.random.rand(*grid.shape) grid.write(fr) grid.write(fi, imag=True) read = grid.read(fr) read_i = grid.read(fi) - read.grid = read.grid + 1j*read_i.grid + read.grid = read.grid + 1j * read_i.grid assert np.allclose(grid.grid, read.grid) assert not grid.geometry is None assert not read.geometry is None @@ -80,11 +88,15 @@ def test_imaginary(sisl_tmp): def test_imaginary_fail_shape(sisl_tmp): - fr = sisl_tmp('GRID_real.cube', _dir) - fi = sisl_tmp('GRID_imag.cube', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) + fr = sisl_tmp("GRID_real.cube", _dir) + fi = sisl_tmp("GRID_imag.cube", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) grid = Grid(0.2, geometry=geom, dtype=np.complex128) - grid.grid = np.random.rand(*grid.shape) + 1j*np.random.rand(*grid.shape) + grid.grid = np.random.rand(*grid.shape) + 1j * np.random.rand(*grid.shape) grid.write(fr) # Assert it fails on shape @@ -95,11 +107,15 @@ def test_imaginary_fail_shape(sisl_tmp): def test_imaginary_fail_geometry(sisl_tmp): - fr = sisl_tmp('GRID_real.cube', _dir) - fi = sisl_tmp('GRID_imag.cube', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) + fr = sisl_tmp("GRID_real.cube", _dir) + fi = sisl_tmp("GRID_imag.cube", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) grid = Grid(0.2, geometry=geom, dtype=np.complex128) - grid.grid = np.random.rand(*grid.shape) + 1j*np.random.rand(*grid.shape) + grid.grid = np.random.rand(*grid.shape) + 1j * np.random.rand(*grid.shape) grid.write(fr) # Assert it fails on geometry diff --git a/src/sisl/io/tests/test_object.py b/src/sisl/io/tests/test_object.py index ae22e01881..99630da557 100644 --- a/src/sisl/io/tests/test_object.py +++ b/src/sisl/io/tests/test_object.py @@ -9,7 +9,14 @@ import numpy as np import pytest -from sisl import DensityMatrix, EnergyDensityMatrix, Geometry, Grid, Hamiltonian, Lattice +from sisl import ( + DensityMatrix, + EnergyDensityMatrix, + Geometry, + Grid, + Hamiltonian, + Lattice, +) from sisl._environ import sisl_environ from sisl.io import * from sisl.io.siesta.binaries import _gfSileSiesta @@ -69,7 +76,6 @@ def test_get_out_context(): class TestObject: - def test_siesta_sources(self): pytest.importorskip("sisl.io.siesta._siesta") @@ -97,7 +103,9 @@ def test_direct_string_instantiation(self, sisl_tmp): assert isinstance(sile, Sile) assert not os.path.exists(fp) - @pytest.mark.parametrize("sile", _fnames("test", ["cube", "CUBE", "cube.gz", "CUBE.gz"])) + @pytest.mark.parametrize( + "sile", _fnames("test", ["cube", "CUBE", "cube.gz", "CUBE.gz"]) + ) def test_cube(self, sile): s = gs(sile) for obj in [BaseSile, Sile, cubeSile]: @@ -133,7 +141,9 @@ def test_scaleup_rham(self, sile): for obj in [BaseSile, Sile, SileScaleUp, rhamSileScaleUp]: assert isinstance(s, obj) - @pytest.mark.parametrize("sile", _fnames("test", ["fdf", "fdf.gz", "FDF.gz", "FDF"])) + @pytest.mark.parametrize( + "sile", _fnames("test", ["fdf", "fdf.gz", "FDF.gz", "FDF"]) + ) def test_siesta_fdf(self, sile): s = gs(sile) for obj in [BaseSile, Sile, SileSiesta, fdfSileSiesta]: @@ -194,7 +204,13 @@ def test_tbtrans_nc(self, sile): def test_phtrans_nc(self, sile): pytest.importorskip("netCDF4") s = gs(sile, _open=False) - for obj in [BaseSile, SileCDF, SileCDFTBtrans, tbtncSileTBtrans, phtncSilePHtrans]: + for obj in [ + BaseSile, + SileCDF, + SileCDFTBtrans, + tbtncSileTBtrans, + phtncSilePHtrans, + ]: assert isinstance(s, obj) @pytest.mark.parametrize("sile", _fnames("CONTCAR", ["", "gz"])) @@ -209,19 +225,25 @@ def test_vasp_poscar(self, sile): for obj in [BaseSile, Sile, SileVASP, carSileVASP]: assert isinstance(s, obj) - @pytest.mark.parametrize("sile", _fnames("test", ["xyz", "XYZ", "xyz.gz", "XYZ.gz"])) + @pytest.mark.parametrize( + "sile", _fnames("test", ["xyz", "XYZ", "xyz.gz", "XYZ.gz"]) + ) def test_xyz(self, sile): s = gs(sile) for obj in [BaseSile, Sile, xyzSile]: assert isinstance(s, obj) - @pytest.mark.parametrize("sile", _fnames("test", ["molf", "MOLF", "molf.gz", "MOLF.gz"])) + @pytest.mark.parametrize( + "sile", _fnames("test", ["molf", "MOLF", "molf.gz", "MOLF.gz"]) + ) def test_molf(self, sile): s = gs(sile) for obj in [BaseSile, Sile, moldenSile]: assert isinstance(s, obj) - @pytest.mark.parametrize("sile", _fnames("test", ["xsf", "XSF", "xsf.gz", "XSF.gz"])) + @pytest.mark.parametrize( + "sile", _fnames("test", ["xsf", "XSF", "xsf.gz", "XSF.gz"]) + ) def test_xsf(self, sile): s = gs(sile) for obj in [BaseSile, Sile, xsfSile]: @@ -256,20 +278,22 @@ def test_read_write_lattice(self, sisl_tmp, sisl_system, sile): if issubclass(sile, (_ncSileTBtrans, deltancSileTBtrans)): return if sys.platform.startswith("win") and issubclass(sile, chgSileVASP): - pytest.xfail("Windows reading/writing supercell fails for some unknown reason") + pytest.xfail( + "Windows reading/writing supercell fails for some unknown reason" + ) # Write sile(f, mode="w").write_lattice(L) # Read 1 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: l = s.read_lattice() - assert l.equal(L, tol=1e-3) # pdb files have 8.3 for atomic coordinates + assert l.equal(L, tol=1e-3) # pdb files have 8.3 for atomic coordinates except UnicodeDecodeError as e: pass # Read 2 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: l = Lattice.read(s) assert l.equal(L, tol=1e-3) except UnicodeDecodeError as e: @@ -287,19 +311,23 @@ def test_read_write_lattice(self, sisl_tmp, sisl_system, sile): if issubclass(sile, (_ncSileTBtrans, deltancSileTBtrans)): return if sys.platform.startswith("win") and issubclass(sile, chgSileVASP): - pytest.xfail("Windows reading/writing supercell fails for some unknown reason") + pytest.xfail( + "Windows reading/writing supercell fails for some unknown reason" + ) # Write sile(f, mode="w").write_lattice(L) # Read 1 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: l = s.read_lattice() - assert l.equal(L, tol=1e-3) # pdb files have 8.3 for atomic coordinates + assert l.equal(L, tol=1e-3) # pdb files have 8.3 for atomic coordinates except UnicodeDecodeError as e: pass - @pytest.mark.parametrize("sile", _my_intersect(["read_geometry"], ["write_geometry"])) + @pytest.mark.parametrize( + "sile", _my_intersect(["read_geometry"], ["write_geometry"]) + ) def test_read_write_geometry(self, sisl_tmp, sisl_system, sile): if issubclass(sile, SileCDF): pytest.importorskip("netCDF4") @@ -311,22 +339,26 @@ def test_read_write_geometry(self, sisl_tmp, sisl_system, sile): if issubclass(sile, (_ncSileTBtrans, deltancSileTBtrans)): return if sys.platform.startswith("win") and issubclass(sile, chgSileVASP): - pytest.xfail("Windows reading/writing supercell fails for some unknown reason") + pytest.xfail( + "Windows reading/writing supercell fails for some unknown reason" + ) # Write sile(f, mode="w").write_geometry(G) # Read 1 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: g = s.read_geometry() if isinstance(g, list): g = g[0] - assert g.equal(G, R=False, tol=1e-3) # pdb files have 8.3 for atomic coordinates + assert g.equal( + G, R=False, tol=1e-3 + ) # pdb files have 8.3 for atomic coordinates except UnicodeDecodeError as e: pass # Read 2 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: g = Geometry.read(s) if isinstance(g, list): g = g[0] @@ -334,7 +366,9 @@ def test_read_write_geometry(self, sisl_tmp, sisl_system, sile): except UnicodeDecodeError as e: pass - @pytest.mark.parametrize("sile", _my_intersect(["read_hamiltonian"], ["write_hamiltonian"])) + @pytest.mark.parametrize( + "sile", _my_intersect(["read_hamiltonian"], ["write_hamiltonian"]) + ) def test_read_write_hamiltonian(self, sisl_tmp, sisl_system, sile): if issubclass(sile, SileCDF): pytest.importorskip("netCDF4") @@ -347,24 +381,26 @@ def test_read_write_hamiltonian(self, sisl_tmp, sisl_system, sile): H.construct([[0.1, 1.45], [0.1, -2.7]]) f = sisl_tmp("test_read_write_hamiltonian.win", _dir) # Write - with sile(f, mode='w') as s: + with sile(f, mode="w") as s: s.write_hamiltonian(H) # Read 1 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: h = s.read_hamiltonian() assert H.spsame(h) except UnicodeDecodeError as e: pass # Read 2 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: h = Hamiltonian.read(s) assert H.spsame(h) except UnicodeDecodeError as e: pass - @pytest.mark.parametrize("sile", _my_intersect(["read_density_matrix"], ["write_density_matrix"])) + @pytest.mark.parametrize( + "sile", _my_intersect(["read_density_matrix"], ["write_density_matrix"]) + ) def test_read_write_density_matrix(self, sisl_tmp, sisl_system, sile): if issubclass(sile, SileCDF): pytest.importorskip("netCDF4") @@ -374,24 +410,27 @@ def test_read_write_density_matrix(self, sisl_tmp, sisl_system, sile): DM.construct([[0.1, 1.45], [0.1, -2.7]]) f = sisl_tmp("test_read_write_density_matrix.win", _dir) # Write - with sile(f, mode='w') as s: + with sile(f, mode="w") as s: s.write_density_matrix(DM) # Read 1 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: dm = s.read_density_matrix(geometry=DM.geometry) assert DM.spsame(dm) except UnicodeDecodeError as e: pass # Read 2 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: dm = DensityMatrix.read(s, geometry=DM.geometry) assert DM.spsame(dm) except UnicodeDecodeError as e: pass - @pytest.mark.parametrize("sile", _my_intersect(["read_energy_density_matrix"], ["write_energy_density_matrix"])) + @pytest.mark.parametrize( + "sile", + _my_intersect(["read_energy_density_matrix"], ["write_energy_density_matrix"]), + ) def test_read_write_energy_density_matrix(self, sisl_tmp, sisl_system, sile): if issubclass(sile, SileCDF): pytest.importorskip("netCDF4") @@ -401,24 +440,26 @@ def test_read_write_energy_density_matrix(self, sisl_tmp, sisl_system, sile): EDM.construct([[0.1, 1.45], [0.1, -2.7]]) f = sisl_tmp("test_read_write_energy_density_matrix.win", _dir) # Write - with sile(f, mode='w') as s: + with sile(f, mode="w") as s: s.write_energy_density_matrix(EDM) # Read 1 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: edm = s.read_energy_density_matrix(geometry=EDM.geometry) assert EDM.spsame(edm) except UnicodeDecodeError as e: pass # Read 2 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: edm = EnergyDensityMatrix.read(s, geometry=EDM.geometry) assert EDM.spsame(edm) except UnicodeDecodeError as e: pass - @pytest.mark.parametrize("sile", _my_intersect(["read_hamiltonian"], ["write_hamiltonian"])) + @pytest.mark.parametrize( + "sile", _my_intersect(["read_hamiltonian"], ["write_hamiltonian"]) + ) def test_read_write_hamiltonian_overlap(self, sisl_tmp, sisl_system, sile): if issubclass(sile, SileCDF): pytest.importorskip("netCDF4") @@ -431,24 +472,26 @@ def test_read_write_hamiltonian_overlap(self, sisl_tmp, sisl_system, sile): H.construct([[0.1, 1.45], [(0.1, 1), (-2.7, 0.1)]]) f = sisl_tmp("test_read_write_hamiltonian_overlap.win", _dir) # Write - with sile(f, mode='w') as s: + with sile(f, mode="w") as s: s.write_hamiltonian(H) # Read 1 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: h = s.read_hamiltonian() assert H.spsame(h) except UnicodeDecodeError as e: pass # Read 2 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: h = Hamiltonian.read(s) assert H.spsame(h) except UnicodeDecodeError as e: pass - @pytest.mark.filterwarnings("ignore", message="*gridncSileSiesta.read_grid cannot determine") + @pytest.mark.filterwarnings( + "ignore", message="*gridncSileSiesta.read_grid cannot determine" + ) @pytest.mark.parametrize("sile", _my_intersect(["read_grid"], ["write_grid"])) def test_read_write_grid(self, sisl_tmp, sisl_system, sile): if issubclass(sile, SileCDF): @@ -475,7 +518,7 @@ def test_read_write_grid(self, sisl_tmp, sisl_system, sile): pass # Read 2 try: - with sile(f, mode='r') as s: + with sile(f, mode="r") as s: g = Grid.read(s) assert np.allclose(g.grid, G.grid, atol=1e-5) except UnicodeDecodeError as e: @@ -510,5 +553,6 @@ class tmpSile(xyzSile, buffer_cls=customBuffer): assert tmpSile._buffer_cls.__name__ == "customBuffer" with pytest.raises(TypeError): + class tmpSile(xyzSile, buffer_cls=object): pass diff --git a/src/sisl/io/tests/test_table.py b/src/sisl/io/tests/test_table.py index 89823a91a5..23cc8d0bec 100644 --- a/src/sisl/io/tests/test_table.py +++ b/src/sisl/io/tests/test_table.py @@ -9,15 +9,15 @@ from sisl.io.table import * pytestmark = [pytest.mark.io, pytest.mark.generic] -_dir = osp.join('sisl', 'io') +_dir = osp.join("sisl", "io") def test_tbl1(sisl_tmp): dat0 = np.arange(2) dat1 = np.arange(2) + 1 - io0 = tableSile(sisl_tmp('t0.dat', _dir), 'w') - io1 = tableSile(sisl_tmp('t1.dat', _dir), 'w') + io0 = tableSile(sisl_tmp("t0.dat", _dir), "w") + io1 = tableSile(sisl_tmp("t1.dat", _dir), "w") io0.write_data(dat0, dat1) io1.write_data((dat0, dat1)) @@ -30,8 +30,8 @@ def test_tbl2(sisl_tmp): dat0 = np.arange(8).reshape(2, 2, 2) dat1 = np.arange(8).reshape(2, 2, 2) + 1 - io0 = tableSile(sisl_tmp('t0.dat', _dir), 'w') - io1 = tableSile(sisl_tmp('t1.dat', _dir), 'w') + io0 = tableSile(sisl_tmp("t0.dat", _dir), "w") + io1 = tableSile(sisl_tmp("t1.dat", _dir), "w") io0.write_data(dat0, dat1) io1.write_data((dat0, dat1)) @@ -46,13 +46,13 @@ def test_tbl3(sisl_tmp): DAT = np.stack([dat0, dat1]) DAT.shape = (-1, 2, 2) - io = tableSile(sisl_tmp('t.dat', _dir), 'w') + io = tableSile(sisl_tmp("t.dat", _dir), "w") io.write_data(dat0, dat1) - dat = tableSile(io.file, 'r').read_data() + dat = tableSile(io.file, "r").read_data() assert np.allclose(dat, DAT) - io = tableSile(io.file, 'w') + io = tableSile(io.file, "w") io.write_data((dat0, dat1)) - dat = tableSile(io.file, 'r').read_data() + dat = tableSile(io.file, "r").read_data() assert np.allclose(dat, DAT) @@ -61,13 +61,13 @@ def test_tbl4(sisl_tmp): dat1 = np.arange(8) + 1 DAT = np.stack([dat0, dat1]) - io = tableSile(sisl_tmp('t.dat', _dir), 'w') + io = tableSile(sisl_tmp("t.dat", _dir), "w") io.write_data(dat0, dat1) - dat = tableSile(io.file, 'r').read_data() + dat = tableSile(io.file, "r").read_data() assert np.allclose(dat, DAT) - io = tableSile(io.file, 'w') + io = tableSile(io.file, "w") io.write_data((dat0, dat1)) - dat = tableSile(io.file, 'r').read_data() + dat = tableSile(io.file, "r").read_data() assert np.allclose(dat, DAT) @@ -76,55 +76,55 @@ def test_tbl_automatic_stack(sisl_tmp): dat1 = np.arange(8).reshape(2, 4) + 1 DAT = np.vstack([dat0, dat1]) - io = tableSile(sisl_tmp('t.dat', _dir), 'w') + io = tableSile(sisl_tmp("t.dat", _dir), "w") io.write_data(dat0, dat1) - dat = tableSile(io.file, 'r').read_data() + dat = tableSile(io.file, "r").read_data() assert np.allclose(dat, DAT) - io = tableSile(io.file, 'w') + io = tableSile(io.file, "w") io.write_data((dat0, dat1)) - dat = tableSile(io.file, 'r').read_data() + dat = tableSile(io.file, "r").read_data() assert np.allclose(dat, DAT) def test_tbl_accumulate(sisl_tmp): DAT = np.arange(12).reshape(3, 4) + 1 - io = tableSile(sisl_tmp('t.dat', _dir), 'w') + io = tableSile(sisl_tmp("t.dat", _dir), "w") with io: - io.write_data(header='Hello') + io.write_data(header="Hello") for d in DAT.T: # Write a row at a time (otherwise 1D data will be written as a single column) io.write_data(d.reshape(-1, 1)) - dat, header = tableSile(io.file, 'r').read_data(ret_header=True) + dat, header = tableSile(io.file, "r").read_data(ret_header=True) assert np.allclose(dat, DAT) - assert header.index('Hello') >= 0 + assert header.index("Hello") >= 0 def test_tbl_accumulate_1d(sisl_tmp): DAT = np.arange(12).reshape(3, 4) + 1 - io = tableSile(sisl_tmp('t.dat', _dir), 'w') + io = tableSile(sisl_tmp("t.dat", _dir), "w") with io: - io.write_data(header='Hello') + io.write_data(header="Hello") for d in DAT: # 1D data will be written as a single column io.write_data(d) - dat, header = tableSile(io.file, 'r').read_data(ret_header=True) + dat, header = tableSile(io.file, "r").read_data(ret_header=True) assert dat.ndim == 1 assert np.allclose(dat, DAT.ravel()) - assert header.index('Hello') >= 0 + assert header.index("Hello") >= 0 -@pytest.mark.parametrize("delimiter", ['\t', ' ', ',', ':', 'M']) +@pytest.mark.parametrize("delimiter", ["\t", " ", ",", ":", "M"]) def test_tbl5(sisl_tmp, delimiter): dat0 = np.arange(8) dat1 = np.arange(8) + 1 DAT = np.stack([dat0, dat1]) - io = tableSile(sisl_tmp('t.dat', _dir), 'w') + io = tableSile(sisl_tmp("t.dat", _dir), "w") io.write_data(dat0, dat1, delimiter=delimiter) - if delimiter in ['\t', ' ', ',']: - dat = tableSile(io.file, 'r').read_data() + if delimiter in ["\t", " ", ","]: + dat = tableSile(io.file, "r").read_data() assert np.allclose(dat, DAT) - dat = tableSile(io.file, 'r').read_data(delimiter=delimiter) + dat = tableSile(io.file, "r").read_data(delimiter=delimiter) assert np.allclose(dat, DAT) diff --git a/src/sisl/io/tests/test_tb.py b/src/sisl/io/tests/test_tb.py index e44302454d..3328cb339f 100644 --- a/src/sisl/io/tests/test_tb.py +++ b/src/sisl/io/tests/test_tb.py @@ -9,12 +9,12 @@ from sisl.io.ham import * pytestmark = [pytest.mark.io, pytest.mark.generic] -_dir = osp.join('sisl', 'io') +_dir = osp.join("sisl", "io") def test_ham1(sisl_tmp, sisl_system): - f = sisl_tmp('gr.ham', _dir) - sisl_system.g.write(hamiltonianSile(f, 'w')) + f = sisl_tmp("gr.ham", _dir) + sisl_system.g.write(hamiltonianSile(f, "w")) g = hamiltonianSile(f).read_geometry() # Assert they are the same @@ -24,7 +24,7 @@ def test_ham1(sisl_tmp, sisl_system): def test_ham2(sisl_tmp, sisl_system): - f = sisl_tmp('gr.ham', _dir) - sisl_system.ham.write(hamiltonianSile(f, 'w')) + f = sisl_tmp("gr.ham", _dir) + sisl_system.ham.write(hamiltonianSile(f, "w")) ham = hamiltonianSile(f).read_hamiltonian() assert ham.spsame(sisl_system.ham) diff --git a/src/sisl/io/tests/test_xsf.py b/src/sisl/io/tests/test_xsf.py index 64ac1de5e8..85ed6b6472 100644 --- a/src/sisl/io/tests/test_xsf.py +++ b/src/sisl/io/tests/test_xsf.py @@ -11,11 +11,11 @@ from sisl.io.xsf import * pytestmark = [pytest.mark.io, pytest.mark.generic] -_dir = osp.join('sisl', 'io') +_dir = osp.join("sisl", "io") def test_default(sisl_tmp): - f = sisl_tmp('GRID_default.xsf', _dir) + f = sisl_tmp("GRID_default.xsf", _dir) grid = Grid(0.2) grid.grid = np.random.rand(*grid.shape) grid.write(f) @@ -23,7 +23,7 @@ def test_default(sisl_tmp): def test_default_size(sisl_tmp): - f = sisl_tmp('GRID_default_size.xsf', _dir) + f = sisl_tmp("GRID_default_size.xsf", _dir) grid = Grid(0.2, lattice=2.0) grid.grid = np.random.rand(*grid.shape) grid.write(f) @@ -31,8 +31,12 @@ def test_default_size(sisl_tmp): def test_geometry(sisl_tmp): - f = sisl_tmp('GRID_geometry.xsf', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) + f = sisl_tmp("GRID_geometry.xsf", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) grid = Grid(0.2, geometry=geom) grid.grid = np.random.rand(*grid.shape) grid.write(f) @@ -40,18 +44,26 @@ def test_geometry(sisl_tmp): def test_imaginary(sisl_tmp): - f = sisl_tmp('GRID_imag.xsf', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) + f = sisl_tmp("GRID_imag.xsf", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) grid = Grid(0.2, geometry=geom, dtype=np.complex128) - grid.grid = np.random.rand(*grid.shape) + 1j*np.random.rand(*grid.shape) + grid.grid = np.random.rand(*grid.shape) + 1j * np.random.rand(*grid.shape) grid.write(f) assert not grid.geometry is None def test_axsf_geoms(sisl_tmp): - f = sisl_tmp('multigeom_nodata.axsf', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) - geoms = [geom.move((i/10, i/10, i/10)) for i in range(3)] + f = sisl_tmp("multigeom_nodata.axsf", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) + geoms = [geom.move((i / 10, i / 10, i / 10)) for i in range(3)] with xsfSile(f, "w", steps=3) as s: for g in geoms: @@ -83,9 +95,13 @@ def test_axsf_geoms(sisl_tmp): def test_axsf_data(sisl_tmp): - f = sisl_tmp('multigeom_data.axsf', _dir) - geom = Geometry(np.random.rand(10, 3), np.random.randint(1, 70, 10), lattice=[10, 10, 10, 45, 60, 90]) - geoms = [geom.move((i/10, i/10, i/10)) for i in range(3)] + f = sisl_tmp("multigeom_data.axsf", _dir) + geom = Geometry( + np.random.rand(10, 3), + np.random.randint(1, 70, 10), + lattice=[10, 10, 10, 45, 60, 90], + ) + geoms = [geom.move((i / 10, i / 10, i / 10)) for i in range(3)] data = np.random.rand(3, 10, 3) with xsfSile(f, "w", steps=3) as s: @@ -102,4 +118,3 @@ def test_axsf_data(sisl_tmp): rgeoms, rdata = s.read_geometry[0](ret_data=True) assert geoms[0].equal(rgeoms) assert np.allclose(rdata, data[0]) - diff --git a/src/sisl/io/tests/test_xyz.py b/src/sisl/io/tests/test_xyz.py index fa98fe91e1..01384622fc 100644 --- a/src/sisl/io/tests/test_xyz.py +++ b/src/sisl/io/tests/test_xyz.py @@ -9,12 +9,12 @@ from sisl.io.xyz import * pytestmark = [pytest.mark.io, pytest.mark.generic] -_dir = osp.join('sisl', 'io') +_dir = osp.join("sisl", "io") def test_xyz1(sisl_tmp, sisl_system): - f = sisl_tmp('gr.xyz', _dir) - sisl_system.g.write(xyzSile(f, 'w')) + f = sisl_tmp("gr.xyz", _dir) + sisl_system.g.write(xyzSile(f, "w")) g = xyzSile(f).read_geometry() # Assert they are the same @@ -26,22 +26,24 @@ def test_xyz1(sisl_tmp, sisl_system): def test_xyz_sisl(sisl_tmp): - f = sisl_tmp('sisl.xyz', _dir) + f = sisl_tmp("sisl.xyz", _dir) - with open(f, 'w') as fh: - fh.write("""3 + with open(f, "w") as fh: + fh.write( + """3 sisl-version=1 nsc=1 1 3 cell=10 0 0 0 12 0 0 0 13 C 0.00000000 0.00000000 0.00000000 C 1.000000 0.00000000 0.00000000 C 2.00000 0.00000000 0.00000000 -""") +""" + ) g = xyzSile(f).read_geometry() # Assert they are the same assert np.allclose(g.cell, [[10, 0, 0], [0, 12, 0], [0, 0, 13]]) assert np.allclose(g.xyz[:, 0], [0, 1, 2]) - assert np.allclose(g.xyz[:, 1], 0.) - assert np.allclose(g.xyz[:, 2], 0.) + assert np.allclose(g.xyz[:, 1], 0.0) + assert np.allclose(g.xyz[:, 2], 0.0) assert np.allclose(g.nsc, [1, 1, 3]) g = xyzSile(f).read_geometry(lattice=[10, 11, 13]) @@ -49,46 +51,52 @@ def test_xyz_sisl(sisl_tmp): def test_xyz_ase(sisl_tmp): - f = sisl_tmp('ase.xyz', _dir) - with open(f, 'w') as fh: - fh.write("""3 + f = sisl_tmp("ase.xyz", _dir) + with open(f, "w") as fh: + fh.write( + """3 Lattice="10 0 0 0 12 0 0 0 13" Properties=species:S:1:pos:R:3 pbc="F F T" C 0.00000000 0.00000000 0.00000000 C 1.000000 0.00000000 0.00000000 C 2.00000 0.00000000 0.00000000 -""") +""" + ) g = xyzSile(f).read_geometry() # Assert they are the same assert np.allclose(g.cell, [[10, 0, 0], [0, 12, 0], [0, 0, 13]]) assert np.allclose(g.xyz[:, 0], [0, 1, 2]) - assert np.allclose(g.xyz[:, 1], 0.) - assert np.allclose(g.xyz[:, 2], 0.) + assert np.allclose(g.xyz[:, 1], 0.0) + assert np.allclose(g.xyz[:, 2], 0.0) assert np.allclose(g.nsc, [1, 1, 1]) assert np.allclose(g.pbc, [False, False, True]) def test_xyz_arbitrary(sisl_tmp): - f = sisl_tmp('ase.xyz', _dir) - with open(f, 'w') as fh: - fh.write("""3 + f = sisl_tmp("ase.xyz", _dir) + with open(f, "w") as fh: + fh.write( + """3 C 0.00000000 0.00000000 0.00000000 C 1.000000 0.00000000 0.00000000 C 2.00000 0.00000000 0.00000000 -""") +""" + ) g = xyzSile(f).read_geometry() # Assert they are the same assert np.allclose(g.xyz[:, 0], [0, 1, 2]) - assert np.allclose(g.xyz[:, 1], 0.) - assert np.allclose(g.xyz[:, 2], 0.) + assert np.allclose(g.xyz[:, 1], 0.0) + assert np.allclose(g.xyz[:, 2], 0.0) assert np.allclose(g.nsc, [1, 1, 1]) + def test_xyz_multiple(sisl_tmp): - f = sisl_tmp('sisl_multiple.xyz', _dir) - with open(f, 'w') as fh: - fh.write("""1 + f = sisl_tmp("sisl_multiple.xyz", _dir) + with open(f, "w") as fh: + fh.write( + """1 C 0.00000000 0.00000000 0.00000000 2 @@ -100,7 +108,8 @@ def test_xyz_multiple(sisl_tmp): C 0.00000000 0.00000000 0.00000000 C 1.000000 0.00000000 0.00000000 C 2.00000 0.00000000 0.00000000 -""") +""" + ) g = xyzSile(f).read_geometry() assert g.na == 1 g = xyzSile(f).read_geometry[1]() diff --git a/src/sisl/io/vasp/__init__.py b/src/sisl/io/vasp/__init__.py index 8084b9b961..5e76f00e1c 100644 --- a/src/sisl/io/vasp/__init__.py +++ b/src/sisl/io/vasp/__init__.py @@ -13,7 +13,7 @@ stdoutSileVASP """ -from .sile import * # isort: split +from .sile import * # isort: split from .car import * from .chg import * from .doscar import * diff --git a/src/sisl/io/vasp/car.py b/src/sisl/io/vasp/car.py index 07f765444b..543e82b5c8 100644 --- a/src/sisl/io/vasp/car.py +++ b/src/sisl/io/vasp/car.py @@ -9,27 +9,28 @@ from sisl.messages import warn from ..sile import add_sile, sile_fh_open, sile_raise_write + # Import sile objects from .sile import SileVASP -__all__ = ['carSileVASP'] +__all__ = ["carSileVASP"] @set_module("sisl.io.vasp") class carSileVASP(SileVASP): - """ CAR VASP files for defining geomtries + """CAR VASP files for defining geomtries This file-object handles both POSCAR and CONTCAR files """ def _setup(self, *args, **kwargs): - """ Setup the `carSile` after initialization """ + """Setup the `carSile` after initialization""" super()._setup(*args, **kwargs) - self._scale = 1. + self._scale = 1.0 @sile_fh_open() def write_geometry(self, geometry, dynamic=True, group_species=False): - r""" Writes the geometry to the contained file + r"""Writes the geometry to the contained file Parameters ---------- @@ -65,13 +66,13 @@ def write_geometry(self, geometry, dynamic=True, group_species=False): idx = _a.arangei(len(geometry)) # LABEL - self._write('sisl output\n') + self._write("sisl output\n") # Scale - self._write(' 1.\n') + self._write(" 1.\n") # Write unit-cell - fmt = (' ' + '{:18.9f}' * 3) + '\n' + fmt = (" " + "{:18.9f}" * 3) + "\n" for i in range(3): self._write(fmt.format(*geometry.cell[i])) @@ -81,7 +82,7 @@ def write_geometry(self, geometry, dynamic=True, group_species=False): ia = 0 while ia < geometry.na: atom = geometry.atoms[ia] - #specie = geometry.atoms.specie[ia] + # specie = geometry.atoms.specie[ia] ia_end = (np.diff(geometry.atoms.specie[ia:]) != 0).nonzero()[0] if len(ia_end) == 0: # remaining atoms @@ -92,34 +93,38 @@ def write_geometry(self, geometry, dynamic=True, group_species=False): d.append(ia_end - ia) ia += d[-1] - fmt = ' {:s}' * len(d) + '\n' + fmt = " {:s}" * len(d) + "\n" self._write(fmt.format(*s)) - fmt = ' {:d}' * len(d) + '\n' + fmt = " {:d}" * len(d) + "\n" self._write(fmt.format(*d)) if dynamic is None: # We write in direct mode dynamic = [None] * len(geometry) + def todyn(fix): - return '\n' + return "\n" + else: - self._write('Selective dynamics\n') - b2s = {True: 'T', False: 'F'} + self._write("Selective dynamics\n") + b2s = {True: "T", False: "F"} + def todyn(fix): if isinstance(fix, bool): - return ' {0} {0} {0}\n'.format(b2s[fix]) - return ' {} {} {}\n'.format(b2s[fix[0]], b2s[fix[1]], b2s[fix[2]]) - self._write('Cartesian\n') + return " {0} {0} {0}\n".format(b2s[fix]) + return " {} {} {}\n".format(b2s[fix[0]], b2s[fix[1]], b2s[fix[2]]) + + self._write("Cartesian\n") if isinstance(dynamic, bool): dynamic = [dynamic] * len(geometry) - fmt = '{:18.9f}' * 3 + fmt = "{:18.9f}" * 3 for ia in geometry: self._write(fmt.format(*geometry.xyz[ia, :]) + todyn(dynamic[idx[ia]])) @sile_fh_open(True) def read_lattice(self): - """ Returns `Lattice` object from the CONTCAR/POSCAR file """ + """Returns `Lattice` object from the CONTCAR/POSCAR file""" # read first line self.readline() # LABEL @@ -136,7 +141,7 @@ def read_lattice(self): @sile_fh_open() def read_geometry(self, ret_dynamic=False): - r""" Returns Geometry object from the CONTCAR/POSCAR file + r"""Returns Geometry object from the CONTCAR/POSCAR file Possibly also return the dynamics (if present). @@ -159,25 +164,28 @@ def read_geometry(self, ret_dynamic=False): # We have no species... # We default to consecutive elements in the # periodic table. - species = [i+1 for i in range(len(species_count))] - err = '\n'.join([ - "POSCAR best format:", - " ", - " <#Specie-1> <#Specie-2>", - "Format not found, the species are defaulted to the first elements of the periodic table."]) + species = [i + 1 for i in range(len(species_count))] + err = "\n".join( + [ + "POSCAR best format:", + " ", + " <#Specie-1> <#Specie-2>", + "Format not found, the species are defaulted to the first elements of the periodic table.", + ] + ) warn(err) # Create list of atoms to be used subsequently - atom = [Atom[spec] - for spec, nsp in zip(species, species_count) - for i in range(nsp)] + atom = [ + Atom[spec] for spec, nsp in zip(species, species_count) for i in range(nsp) + ] # Number of atoms na = len(atom) # check whether this is Selective Dynamics opt = self.readline() - if opt[0] in 'Ss': + if opt[0] in "Ss": dynamics = True # pre-create the dynamic list dynamic = np.empty([na, 3], dtype=np.bool_) @@ -189,7 +197,7 @@ def read_geometry(self, ret_dynamic=False): # Check whether this is in fractional or direct # coordinates (Direct == fractional) cart = False - if opt[0] in 'CcKk': + if opt[0] in "CcKk": cart = True xyz = _a.emptyd([na, 3]) @@ -197,7 +205,7 @@ def read_geometry(self, ret_dynamic=False): line = self.readline().split() xyz[ia, :] = list(map(float, line[:3])) if dynamics: - dynamic[ia] = list(map(lambda x: x.lower() == 't', line[3:6])) + dynamic[ia] = list(map(lambda x: x.lower() == "t", line[3:6])) if cart: # The unit of the coordinates are cartesian @@ -212,12 +220,12 @@ def read_geometry(self, ret_dynamic=False): return geom def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('CAR', carSileVASP, gzip=True) -add_sile('POSCAR', carSileVASP, gzip=True) -add_sile('CONTCAR', carSileVASP, gzip=True) +add_sile("CAR", carSileVASP, gzip=True) +add_sile("POSCAR", carSileVASP, gzip=True) +add_sile("CONTCAR", carSileVASP, gzip=True) diff --git a/src/sisl/io/vasp/chg.py b/src/sisl/io/vasp/chg.py index ceaf878dc0..359ddea48f 100644 --- a/src/sisl/io/vasp/chg.py +++ b/src/sisl/io/vasp/chg.py @@ -18,14 +18,14 @@ @set_module("sisl.io.vasp") class chgSileVASP(carSileVASP): - """ Charge density plus geometry + """Charge density plus geometry This file-object handles the charge-density from VASP """ @sile_fh_open(True) def read_grid(self, index=0, dtype=np.float64, **kwargs): - """ Reads the charge density from the file and returns with a grid (plus geometry) + """Reads the charge density from the file and returns with a grid (plus geometry) Parameters ---------- @@ -83,9 +83,9 @@ def read_grid(self, index=0, dtype=np.float64, **kwargs): # Cut size before proceeding (otherwise it *may* fail) vals = np.array(vals).astype(dtype, copy=False) if is_index: - val = vals[n * index:n * (index+1)].reshape(nz, ny, nx) + val = vals[n * index : n * (index + 1)].reshape(nz, ny, nx) else: - vals = vals[:n * max_index].reshape(-1, nz, ny, nx) + vals = vals[: n * max_index].reshape(-1, nz, ny, nx) val = grid_reduce_indices(vals, index, axis=0) del vals diff --git a/src/sisl/io/vasp/doscar.py b/src/sisl/io/vasp/doscar.py index 6e194d9ba3..dedf087989 100644 --- a/src/sisl/io/vasp/doscar.py +++ b/src/sisl/io/vasp/doscar.py @@ -9,16 +9,16 @@ from ..sile import add_sile, sile_fh_open from .sile import SileVASP -__all__ = ['doscarSileVASP'] +__all__ = ["doscarSileVASP"] @set_module("sisl.io.vasp") class doscarSileVASP(SileVASP): - """ Density of states output """ + """Density of states output""" @sile_fh_open(True) def read_fermi_level(self): - r""" Query the Fermi-level contained in the file + r"""Query the Fermi-level contained in the file Returns ------- @@ -34,7 +34,7 @@ def read_fermi_level(self): @sile_fh_open() def read_data(self): - r""" Read DOS, as calculated and written by VASP + r"""Read DOS, as calculated and written by VASP Returns ------- @@ -50,8 +50,8 @@ def read_data(self): self.readline() # ' CAR ' self.readline() # name line = self.readline().split() - #Emax = float(line[0]) - #Emin = float(line[1]) + # Emax = float(line[0]) + # Emin = float(line[1]) NE = int(line[2]) Ef = float(line[3]) @@ -61,12 +61,12 @@ def read_data(self): ns = (len(line) - 1) // 2 DOS = np.empty([ns, NE], np.float32) E[0] = line[0] - DOS[:, 0] = line[1:ns+1] + DOS[:, 0] = line[1 : ns + 1] for ie in range(1, NE): line = arrayf(self.readline().split()) E[ie] = line[0] - DOS[:, ie] = line[1:ns+1] + DOS[:, ie] = line[1 : ns + 1] return E - Ef, DOS -add_sile('DOSCAR', doscarSileVASP, gzip=True) +add_sile("DOSCAR", doscarSileVASP, gzip=True) diff --git a/src/sisl/io/vasp/eigenval.py b/src/sisl/io/vasp/eigenval.py index 15e03d031b..48ab3b1018 100644 --- a/src/sisl/io/vasp/eigenval.py +++ b/src/sisl/io/vasp/eigenval.py @@ -6,19 +6,20 @@ from sisl._internal import set_module from ..sile import add_sile, sile_fh_open + # Import sile objects from .sile import SileVASP -__all__ = ['eigenvalSileVASP'] +__all__ = ["eigenvalSileVASP"] @set_module("sisl.io.vasp") class eigenvalSileVASP(SileVASP): - """ Kohn-Sham eigenvalues """ + """Kohn-Sham eigenvalues""" @sile_fh_open() def read_data(self, k=False): - r""" Read eigenvalues, as calculated and written by VASP + r"""Read eigenvalues, as calculated and written by VASP Parameters ---------- @@ -49,17 +50,17 @@ def read_data(self, k=False): w = np.empty([nk], np.float64) for ik in range(nk): self.readline() # empty line - line = self.readline().split() # k-point, weight + line = self.readline().split() # k-point, weight kk[ik, :] = list(map(float, line[:3])) w[ik] = float(line[3]) for ib in range(nb): # band, eig_UP, eig_DOWN, pop_UP, pop_DOWN # We currently neglect the populations - E = map(float, self.readline().split()[1:ns+1]) + E = map(float, self.readline().split()[1 : ns + 1]) eigs[:, ik, ib] = list(E) if k: return eigs, kk, w return eigs -add_sile('EIGENVAL', eigenvalSileVASP, gzip=True) +add_sile("EIGENVAL", eigenvalSileVASP, gzip=True) diff --git a/src/sisl/io/vasp/locpot.py b/src/sisl/io/vasp/locpot.py index b81d369a6c..e4801539b9 100644 --- a/src/sisl/io/vasp/locpot.py +++ b/src/sisl/io/vasp/locpot.py @@ -13,19 +13,19 @@ from .car import carSileVASP from .sile import SileVASP -__all__ = ['locpotSileVASP'] +__all__ = ["locpotSileVASP"] @set_module("sisl.io.vasp") class locpotSileVASP(carSileVASP): - """ Electrostatic (or total) potential plus geometry + """Electrostatic (or total) potential plus geometry This file-object handles the electrostatic(total) potential from VASP """ @sile_fh_open(True) def read_grid(self, index=0, dtype=np.float64, **kwargs): - """ Reads the potential (in eV) from the file and returns with a grid (plus geometry) + """Reads the potential (in eV) from the file and returns with a grid (plus geometry) Parameters ---------- @@ -83,9 +83,9 @@ def read_grid(self, index=0, dtype=np.float64, **kwargs): # Cut size before proceeding (otherwise it *may* fail) vals = np.array(vals).astype(dtype).ravel() if is_index: - val = vals[n * index:n * (index + 1)].reshape(nz, ny, nx) + val = vals[n * index : n * (index + 1)].reshape(nz, ny, nx) else: - vals = vals[:n * max_index].reshape(-1, nz, ny, nx) + vals = vals[: n * max_index].reshape(-1, nz, ny, nx) val = grid_reduce_indices(vals, index, axis=0) del vals @@ -101,4 +101,4 @@ def read_grid(self, index=0, dtype=np.float64, **kwargs): return grid -add_sile('LOCPOT', locpotSileVASP, gzip=True) +add_sile("LOCPOT", locpotSileVASP, gzip=True) diff --git a/src/sisl/io/vasp/sile.py b/src/sisl/io/vasp/sile.py index fe575afea1..73046adcfe 100644 --- a/src/sisl/io/vasp/sile.py +++ b/src/sisl/io/vasp/sile.py @@ -13,7 +13,7 @@ def _geometry_group(geometry, ret_index=False): - r""" Order atoms in geometry according to species such that all of one specie is consecutive + r"""Order atoms in geometry according to species such that all of one specie is consecutive When creating VASP input files (`poscarSileVASP` for instance) the equivalent ``POTCAR`` file needs to contain the pseudos for each specie as they are provided @@ -44,7 +44,7 @@ def _geometry_group(geometry, ret_index=False): ia = 0 for _, idx_s in geometry.atoms.iter(species=True): - idx[ia:ia + len(idx_s)] = idx_s + idx[ia : ia + len(idx_s)] = idx_s ia += len(idx_s) assert ia == na diff --git a/src/sisl/io/vasp/stdout.py b/src/sisl/io/vasp/stdout.py index 9b4c2d5aab..b1f8fee7dc 100644 --- a/src/sisl/io/vasp/stdout.py +++ b/src/sisl/io/vasp/stdout.py @@ -19,46 +19,72 @@ @set_module("sisl.io.vasp") class stdoutSileVASP(SileVASP): - """ Output file from VASP """ + """Output file from VASP""" _info_attributes_ = [ - _A("completed", r".*General timing and accounting", - lambda attr, match: lambda : True, default=lambda : False), - _A("accuracy_reached", r".*reached required accuracy", - lambda attr, match: lambda : True, default=lambda : False), + _A( + "completed", + r".*General timing and accounting", + lambda attr, match: lambda: True, + default=lambda: False, + ), + _A( + "accuracy_reached", + r".*reached required accuracy", + lambda attr, match: lambda: True, + default=lambda: False, + ), ] - @deprecation("stdoutSileVASP.completed is deprecated in favor of stdoutSileVASP.info.completed", "0.16.0") + @deprecation( + "stdoutSileVASP.completed is deprecated in favor of stdoutSileVASP.info.completed", + "0.16.0", + ) def completed(self): - """ True if the line "General timing and accounting" was found. """ + """True if the line "General timing and accounting" was found.""" return self.info.completed() - @deprecation("stdoutSileVASP.accuracy_reached is deprecated in favor of stdoutSileVASP.info.accuracy_reached", "0.16.0") + @deprecation( + "stdoutSileVASP.accuracy_reached is deprecated in favor of stdoutSileVASP.info.accuracy_reached", + "0.16.0", + ) def accuracy_reached(self): - """ True if the line "reached required accuracy" was found. """ + """True if the line "reached required accuracy" was found.""" return self.info.accuracy_reached() @sile_fh_open() def cpu_time(self, flag="General timing and accounting"): - """ Returns the consumed cpu time (in seconds) from a given section """ + """Returns the consumed cpu time (in seconds) from a given section""" if flag == "General timing and accounting": nskip, iplace = 3, 5 else: - raise ValueError(f"{self.__class__.__name__}.cpu_time unknown flag '{flag}'") + raise ValueError( + f"{self.__class__.__name__}.cpu_time unknown flag '{flag}'" + ) found = self.step_to(flag, allow_reread=False)[0] if found: for _ in range(nskip): line = self.readline() return float(line.split()[iplace]) - raise KeyError(f"{self.__class__.__name__}.cpu_time could not find flag '{flag}' in file") + raise KeyError( + f"{self.__class__.__name__}.cpu_time could not find flag '{flag}' in file" + ) @SileBinder() @sile_fh_open() - @deprecate_argument("all", None, "use read_energy[:]() instead to get all entries", from_version="0.14") - @deprecation("WARNING: direct calls to stdoutSileVASP.read_energy() no longer returns the last entry! Now the next block on file is returned.", from_version="0.14") + @deprecate_argument( + "all", + None, + "use read_energy[:]() instead to get all entries", + from_version="0.14", + ) + @deprecation( + "WARNING: direct calls to stdoutSileVASP.read_energy() no longer returns the last entry! Now the next block on file is returned.", + from_version="0.14", + ) def read_energy(self): - """ Reads an energy specification block from OUTCAR + """Reads an energy specification block from OUTCAR The function steps to the next occurrence of the "Free energy of the ion-electron system" segment @@ -101,10 +127,12 @@ def read_energy(self): } # read the energy tables - f = self.step_to("Free energy of the ion-electron system", allow_reread=False)[0] + f = self.step_to("Free energy of the ion-electron system", allow_reread=False)[ + 0 + ] if not f: return None - self.readline() # ----- + self.readline() # ----- line = self.readline() E = PropertyDict() while "----" not in line: @@ -127,7 +155,7 @@ def read_energy(self): @SileBinder() @sile_fh_open() def read_trajectory(self): - """ Reads cell+position+force data from OUTCAR for an ionic trajectory step + """Reads cell+position+force data from OUTCAR for an ionic trajectory step The function steps to the block defined by the "VOLUME and BASIS-vectors are now :" line to first read the cell vectors, then it steps to the "TOTAL-FORCE (eV/Angst)" segment @@ -142,17 +170,17 @@ def read_trajectory(self): if not f: return None for i in range(4): - self.readline() # skip 4 lines + self.readline() # skip 4 lines C = [] for i in range(3): line = self.readline() v = line.split() - C.append(v[:3]) # direct lattice vectors + C.append(v[:3]) # direct lattice vectors # read a position-force table f = self.step_to("TOTAL-FORCE (eV/Angst)", allow_reread=False)[0] if not f: return None - self.readline() # ----- + self.readline() # ----- P, F = [], [] line = self.readline() while "----" not in line: @@ -168,7 +196,9 @@ def read_trajectory(self): return step -outSileVASP = deprecation("outSileVASP has been deprecated in favor of stdoutSileVASP.", "0.15")(stdoutSileVASP) +outSileVASP = deprecation( + "outSileVASP has been deprecated in favor of stdoutSileVASP.", "0.15" +)(stdoutSileVASP) add_sile("OUTCAR", stdoutSileVASP, gzip=True) add_sile("vasp.out", stdoutSileVASP, case=False, gzip=True) diff --git a/src/sisl/io/vasp/tests/test_car.py b/src/sisl/io/vasp/tests/test_car.py index 31f3096917..f9410e4b9e 100644 --- a/src/sisl/io/vasp/tests/test_car.py +++ b/src/sisl/io/vasp/tests/test_car.py @@ -10,41 +10,29 @@ from sisl.io.vasp.car import * pytestmark = [pytest.mark.io, pytest.mark.vasp] -_dir = osp.join('sisl', 'io', 'vasp') +_dir = osp.join("sisl", "io", "vasp") def test_geometry_car_mixed(sisl_tmp): - f = sisl_tmp('test_read_write.POSCAR', _dir) - - atoms = [Atom[1], - Atom[2], - Atom[2], - Atom[1], - Atom[1], - Atom[2], - Atom[3]] + f = sisl_tmp("test_read_write.POSCAR", _dir) + + atoms = [Atom[1], Atom[2], Atom[2], Atom[1], Atom[1], Atom[2], Atom[3]] xyz = np.random.rand(len(atoms), 3) geom = Geometry(xyz, atoms, 100) - geom.write(carSileVASP(f, 'w')) + geom.write(carSileVASP(f, "w")) assert carSileVASP(f).read_geometry() == geom def test_geometry_car_group(sisl_tmp): - f = sisl_tmp('test_sort.POSCAR', _dir) - - atoms = [Atom[1], - Atom[2], - Atom[2], - Atom[1], - Atom[1], - Atom[2], - Atom[3]] + f = sisl_tmp("test_sort.POSCAR", _dir) + + atoms = [Atom[1], Atom[2], Atom[2], Atom[1], Atom[1], Atom[2], Atom[3]] xyz = np.random.rand(len(atoms), 3) geom = Geometry(xyz, atoms, 100) - geom.write(carSileVASP(f, 'w'), group_species=True) + geom.write(carSileVASP(f, "w"), group_species=True) assert carSileVASP(f).read_geometry() != geom geom = carSileVASP(f).geometry_group(geom) @@ -52,19 +40,19 @@ def test_geometry_car_group(sisl_tmp): def test_geometry_car_allsame(sisl_tmp): - f = sisl_tmp('test_read_write.POSCAR', _dir) + f = sisl_tmp("test_read_write.POSCAR", _dir) atoms = Atom[1] xyz = np.random.rand(10, 3) geom = Geometry(xyz, atoms, 100) - geom.write(carSileVASP(f, 'w')) + geom.write(carSileVASP(f, "w")) assert carSileVASP(f).read_geometry() == geom def test_geometry_car_dynamic(sisl_tmp): - f = sisl_tmp('test_dynamic.POSCAR', _dir) + f = sisl_tmp("test_dynamic.POSCAR", _dir) atoms = Atom[1] xyz = np.random.rand(10, 3) @@ -73,21 +61,21 @@ def test_geometry_car_dynamic(sisl_tmp): read = carSileVASP(f) # no dynamic (direct geometry) - geom.write(carSileVASP(f, 'w'), dynamic=None) + geom.write(carSileVASP(f, "w"), dynamic=None) g, dyn = read.read_geometry(ret_dynamic=True) assert dyn is None - geom.write(carSileVASP(f, 'w'), dynamic=False) + geom.write(carSileVASP(f, "w"), dynamic=False) g, dyn = read.read_geometry(ret_dynamic=True) assert not np.any(dyn) - geom.write(carSileVASP(f, 'w'), dynamic=True) + geom.write(carSileVASP(f, "w"), dynamic=True) g, dyn = read.read_geometry(ret_dynamic=True) assert np.all(dyn) dynamic = [False] * len(geom) dynamic[0] = [True, False, True] - geom.write(carSileVASP(f, 'w'), dynamic=dynamic) + geom.write(carSileVASP(f, "w"), dynamic=dynamic) g, dyn = read.read_geometry(ret_dynamic=True) assert np.array_equal(dynamic[0], dyn[0]) assert not np.any(dyn[1:]) diff --git a/src/sisl/io/vasp/tests/test_chg.py b/src/sisl/io/vasp/tests/test_chg.py index 659e92b5dc..11476857ac 100644 --- a/src/sisl/io/vasp/tests/test_chg.py +++ b/src/sisl/io/vasp/tests/test_chg.py @@ -9,11 +9,11 @@ from sisl.io.vasp.chg import * pytestmark = [pytest.mark.io, pytest.mark.vasp] -_dir = osp.join('sisl', 'io', 'vasp') +_dir = osp.join("sisl", "io", "vasp") def test_graphene_chg(sisl_files): - f = sisl_files(_dir, 'graphene', 'CHG') + f = sisl_files(_dir, "graphene", "CHG") grid = chgSileVASP(f).read_grid() gridf32 = chgSileVASP(f).read_grid(dtype=np.float32) geom = chgSileVASP(f).read_geometry() @@ -24,7 +24,7 @@ def test_graphene_chg(sisl_files): def test_graphene_chgcar(sisl_files): - f = sisl_files(_dir, 'graphene', 'CHGCAR') + f = sisl_files(_dir, "graphene", "CHGCAR") grid = chgSileVASP(f).read_grid() gridf32 = chgSileVASP(f).read_grid(index=0, dtype=np.float32) geom = chgSileVASP(f).read_geometry() @@ -35,7 +35,7 @@ def test_graphene_chgcar(sisl_files): def test_graphene_chgcar_index_float(sisl_files): - f = sisl_files(_dir, 'graphene', 'CHGCAR') + f = sisl_files(_dir, "graphene", "CHGCAR") grid = chgSileVASP(f).read_grid() gridh = chgSileVASP(f).read_grid(index=[0.5]) diff --git a/src/sisl/io/vasp/tests/test_doscar.py b/src/sisl/io/vasp/tests/test_doscar.py index a88d9ce53f..9f0b972968 100644 --- a/src/sisl/io/vasp/tests/test_doscar.py +++ b/src/sisl/io/vasp/tests/test_doscar.py @@ -9,9 +9,9 @@ from sisl.io.vasp.doscar import * pytestmark = [pytest.mark.io, pytest.mark.vasp] -_dir = osp.join('sisl', 'io', 'vasp') +_dir = osp.join("sisl", "io", "vasp") def test_graphene_doscar(sisl_files): - f = sisl_files(_dir, 'graphene', 'DOSCAR') + f = sisl_files(_dir, "graphene", "DOSCAR") E, DOS = doscarSileVASP(f).read_data() diff --git a/src/sisl/io/vasp/tests/test_eigenval.py b/src/sisl/io/vasp/tests/test_eigenval.py index 556940cf43..6ebd7d4f40 100644 --- a/src/sisl/io/vasp/tests/test_eigenval.py +++ b/src/sisl/io/vasp/tests/test_eigenval.py @@ -9,9 +9,9 @@ from sisl.io.vasp.eigenval import * pytestmark = [pytest.mark.io, pytest.mark.vasp] -_dir = osp.join('sisl', 'io', 'vasp') +_dir = osp.join("sisl", "io", "vasp") def test_read_eigenval(sisl_files): - f = sisl_files(_dir, 'graphene', 'EIGENVAL') + f = sisl_files(_dir, "graphene", "EIGENVAL") eigs = eigenvalSileVASP(f).read_data() diff --git a/src/sisl/io/vasp/tests/test_locpot.py b/src/sisl/io/vasp/tests/test_locpot.py index c0d348d2de..536a87e172 100644 --- a/src/sisl/io/vasp/tests/test_locpot.py +++ b/src/sisl/io/vasp/tests/test_locpot.py @@ -9,11 +9,11 @@ from sisl.io.vasp.locpot import * pytestmark = [pytest.mark.io, pytest.mark.vasp] -_dir = osp.join('sisl', 'io', 'vasp') +_dir = osp.join("sisl", "io", "vasp") def test_graphene_locpot(sisl_files): - f = sisl_files(_dir, 'graphene', 'LOCPOT') + f = sisl_files(_dir, "graphene", "LOCPOT") gridf64 = locpotSileVASP(f).read_grid() gridf32 = locpotSileVASP(f).read_grid(dtype=np.float32) geom = locpotSileVASP(f).read_geometry() @@ -24,7 +24,7 @@ def test_graphene_locpot(sisl_files): def test_graphene_locpot_index_float(sisl_files): - f = sisl_files(_dir, 'graphene', 'LOCPOT') + f = sisl_files(_dir, "graphene", "LOCPOT") grid = locpotSileVASP(f).read_grid() gridh = locpotSileVASP(f).read_grid(index=[0.5]) diff --git a/src/sisl/io/vasp/tests/test_stdout.py b/src/sisl/io/vasp/tests/test_stdout.py index 77de8cb6e4..c150456ee9 100644 --- a/src/sisl/io/vasp/tests/test_stdout.py +++ b/src/sisl/io/vasp/tests/test_stdout.py @@ -9,19 +9,19 @@ from sisl.io.vasp.stdout import stdoutSileVASP pytestmark = [pytest.mark.io, pytest.mark.vasp] -_dir = osp.join('sisl', 'io', 'vasp') +_dir = osp.join("sisl", "io", "vasp") def test_diamond_outcar_energies(sisl_files): - f = sisl_files(_dir, 'diamond', 'OUTCAR') + f = sisl_files(_dir, "diamond", "OUTCAR") f = stdoutSileVASP(f) E0 = f.read_energy() E = f.read_energy[-1]() Eall = f.read_energy[:]() - assert E0.sigma0 == 0.8569373 # first block - assert E.sigma0 == -18.18677613 # last block + assert E0.sigma0 == 0.8569373 # first block + assert E.sigma0 == -18.18677613 # last block assert E0 == Eall[0] assert E == Eall[-1] @@ -30,22 +30,22 @@ def test_diamond_outcar_energies(sisl_files): def test_diamond_outcar_cputime(sisl_files): - f = sisl_files(_dir, 'diamond', 'OUTCAR') + f = sisl_files(_dir, "diamond", "OUTCAR") f = stdoutSileVASP(f) - assert f.cpu_time() > 0. + assert f.cpu_time() > 0.0 assert f.info.completed() def test_diamond_outcar_completed(sisl_files): - f = sisl_files(_dir, 'diamond', 'OUTCAR') + f = sisl_files(_dir, "diamond", "OUTCAR") f = stdoutSileVASP(f) assert f.info.completed() def test_diamond_outcar_trajectory(sisl_files): - f = sisl_files(_dir, 'diamond', 'OUTCAR') + f = sisl_files(_dir, "diamond", "OUTCAR") f = stdoutSileVASP(f) step = f.read_trajectory() @@ -60,7 +60,7 @@ def test_diamond_outcar_trajectory(sisl_files): def test_graphene_relax_outcar_trajectory(sisl_files): - f = sisl_files(_dir, 'graphene_relax', 'OUTCAR') + f = sisl_files(_dir, "graphene_relax", "OUTCAR") f = stdoutSileVASP(f) step = f.read_trajectory[9]() @@ -84,7 +84,7 @@ def test_graphene_relax_outcar_trajectory(sisl_files): def test_graphene_md_outcar_trajectory(sisl_files): - f = sisl_files(_dir, 'graphene_md', 'OUTCAR') + f = sisl_files(_dir, "graphene_md", "OUTCAR") f = stdoutSileVASP(f) step = f.read_trajectory[99]() diff --git a/src/sisl/io/wannier90/__init__.py b/src/sisl/io/wannier90/__init__.py index 657c89b63d..dccfc40808 100644 --- a/src/sisl/io/wannier90/__init__.py +++ b/src/sisl/io/wannier90/__init__.py @@ -11,5 +11,5 @@ winSileWannier90 -- input file """ -from .sile import * # isort: split +from .sile import * # isort: split from .seedname import * diff --git a/src/sisl/io/wannier90/seedname.py b/src/sisl/io/wannier90/seedname.py index 092293647f..af538a9402 100644 --- a/src/sisl/io/wannier90/seedname.py +++ b/src/sisl/io/wannier90/seedname.py @@ -15,14 +15,15 @@ from sisl.unit import unit_convert from ..sile import * + # Import sile objects from .sile import SileWannier90 -__all__ = ['winSileWannier90'] +__all__ = ["winSileWannier90"] class winSileWannier90(SileWannier90): - """ Wannier seedname input file object + """Wannier seedname input file object This `Sile` enables easy interaction with the Wannier90 code. @@ -58,38 +59,40 @@ class winSileWannier90(SileWannier90): """ def _setup(self, *args, **kwargs): - """ Setup `winSileWannier90` after initialization """ + """Setup `winSileWannier90` after initialization""" super()._setup(*args, **kwargs) - self._comment = ['!', '#'] - self._seed = str(self.file).replace('.win', '') + self._comment = ["!", "#"] + self._seed = str(self.file).replace(".win", "") def _set_file(self, suffix=None): - """ Update readed file """ + """Update readed file""" if suffix is None: - self._file = Path(self._seed + '.win') + self._file = Path(self._seed + ".win") else: self._file = Path(self._seed + suffix) @sile_fh_open() def _read_lattice(self): - """ Deferred routine """ + """Deferred routine""" - f, l = self.step_to('unit_cell_cart', case=False) + f, l = self.step_to("unit_cell_cart", case=False) if not f: - raise ValueError("The unit-cell vectors could not be found in the seed-file.") + raise ValueError( + "The unit-cell vectors could not be found in the seed-file." + ) l = self.readline() lines = [] - while not l.startswith('end'): + while not l.startswith("end"): lines.append(l) l = self.readline() # Check whether the first element is a specification of the units pos_unit = lines[0].split() if len(pos_unit) > 2: - unit = 1. + unit = 1.0 else: - unit = unit_convert(pos_unit[0].capitalize(), 'Ang') + unit = unit_convert(pos_unit[0].capitalize(), "Ang") # Remove the line with the unit... lines.pop(0) @@ -101,7 +104,7 @@ def _read_lattice(self): return Lattice(cell * unit) def read_lattice(self): - """ Reads a `Lattice` and creates the Wannier90 cell """ + """Reads a `Lattice` and creates the Wannier90 cell""" # Reset self._set_file() @@ -109,7 +112,7 @@ def read_lattice(self): @sile_fh_open() def _read_geometry_centres(self, *args, **kwargs): - """ Defered routine """ + """Defered routine""" nc = int(self.readline()) @@ -122,25 +125,27 @@ def _read_geometry_centres(self, *args, **kwargs): for ia in range(nc): l = self.readline().split() sp[ia] = l.pop(0) - if sp[ia] == 'X': + if sp[ia] == "X": na = ia + 1 xyz[ia, :] = [float(k) for k in l[:3]] - return Geometry(xyz[:na, :], atoms='H') + return Geometry(xyz[:na, :], atoms="H") @sile_fh_open() def _read_geometry(self, lattice, *args, **kwargs): - """ Defered routine """ + """Defered routine""" is_frac = True - f, _ = self.step_to('atoms_frac', case=False) + f, _ = self.step_to("atoms_frac", case=False) if not f: is_frac = False self.fh.seek(0) - f, _ = self.step_to('atoms_cart', case=False) + f, _ = self.step_to("atoms_cart", case=False) if not f: - raise ValueError("The geometry coordinates (atoms_frac/cart) could not be found in the seed-file.") + raise ValueError( + "The geometry coordinates (atoms_frac/cart) could not be found in the seed-file." + ) # Species and coordinate list s = [] @@ -148,19 +153,19 @@ def _read_geometry(self, lattice, *args, **kwargs): # Read the next line to determine the units if is_frac: - unit = 1. + unit = 1.0 else: unit = self.readline() if len(unit.split()) > 1: l = unit.split() s.append(l[0]) xyz.append(list(map(float, l[1:4]))) - unit = 1. + unit = 1.0 else: - unit = unit_convert(unit.strip().capitalize(), 'Ang') + unit = unit_convert(unit.strip().capitalize(), "Ang") l = self.readline() - while not 'end' in l: + while not "end" in l: # Get the species and l = l.split() s.append(l[0]) @@ -175,14 +180,16 @@ def _read_geometry(self, lattice, *args, **kwargs): return Geometry(xyz, atoms=s, lattice=lattice) - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def read_geometry(self, *args, **kwargs): - """ Reads a `Geometry` and creates the Wannier90 cell """ + """Reads a `Geometry` and creates the Wannier90 cell""" # Read in the super-cell lattice = self.read_lattice() - self._set_file('_centres.xyz') + self._set_file("_centres.xyz") if self.file.is_file(): geom = self._read_geometry_centres() else: @@ -197,29 +204,31 @@ def read_geometry(self, *args, **kwargs): return geom @sile_fh_open() - def _write_lattice(self, lattice, fmt='.8f', *args, **kwargs): - """ Writes the supercel to the contained file """ + def _write_lattice(self, lattice, fmt=".8f", *args, **kwargs): + """Writes the supercel to the contained file""" # Check that we can write to the file sile_raise_write(self) - fmt_str = ' {{0:{0}}} {{1:{0}}} {{2:{0}}}\n'.format(fmt) + fmt_str = " {{0:{0}}} {{1:{0}}} {{2:{0}}}\n".format(fmt) - self._write('begin unit_cell_cart\n') - self._write(' Ang\n') + self._write("begin unit_cell_cart\n") + self._write(" Ang\n") self._write(fmt_str.format(*lattice.cell[0, :])) self._write(fmt_str.format(*lattice.cell[1, :])) self._write(fmt_str.format(*lattice.cell[2, :])) - self._write('end unit_cell_cart\n') + self._write("end unit_cell_cart\n") - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") - def write_lattice(self, lattice, fmt='.8f', *args, **kwargs): - """ Writes the supercell to the contained file """ + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) + def write_lattice(self, lattice, fmt=".8f", *args, **kwargs): + """Writes the supercell to the contained file""" self._set_file() self._write_lattice(lattice, fmt, *args, **kwargs) @sile_fh_open() - def _write_geometry(self, geom, fmt='.8f', *args, **kwargs): - """ Writes the geometry to the contained file """ + def _write_geometry(self, geom, fmt=".8f", *args, **kwargs): + """Writes the geometry to the contained file""" # Check that we can write to the file sile_raise_write(self) @@ -228,35 +237,35 @@ def _write_geometry(self, geom, fmt='.8f', *args, **kwargs): # and if it isn't 'a', then it cleans it... :( self._write_lattice(geom.lattice, fmt, *args, **kwargs) - fmt_str = ' {{1:2s}} {{2:{0}}} {{3:{0}}} {{4:{0}}} # {{0}}\n'.format(fmt) + fmt_str = " {{1:2s}} {{2:{0}}} {{3:{0}}} {{4:{0}}} # {{0}}\n".format(fmt) - if kwargs.get('frac', False): + if kwargs.get("frac", False): # Get the fractional coordinates fxyz = geom.fxyz[:, :] - self._write('begin atoms_frac\n') + self._write("begin atoms_frac\n") for ia, a, _ in geom.iter_species(): self._write(fmt_str.format(ia + 1, a.symbol, *fxyz[ia, :])) - self._write('end atoms_frac\n') + self._write("end atoms_frac\n") else: - self._write('begin atoms_cart\n') - self._write(' Ang\n') + self._write("begin atoms_cart\n") + self._write(" Ang\n") for ia, a, _ in geom.iter_species(): self._write(fmt_str.format(ia + 1, a.symbol, *geom.xyz[ia, :])) - self._write('end atoms_cart\n') + self._write("end atoms_cart\n") - def write_geometry(self, geom, fmt='.8f', *args, **kwargs): - """ Writes the geometry to the contained file """ + def write_geometry(self, geom, fmt=".8f", *args, **kwargs): + """Writes the geometry to the contained file""" self._set_file() self._write_geometry(geom, fmt, *args, **kwargs) @sile_fh_open() def _read_hamiltonian(self, geom, dtype=np.float64, **kwargs): - """ Reads a Hamiltonian + """Reads a Hamiltonian Reads the Hamiltonian model """ - cutoff = kwargs.get('cutoff', 0.00001) + cutoff = kwargs.get("cutoff", 0.00001) # Rewind to ensure we can read the entire matrix structure self.fh.seek(0) @@ -267,8 +276,11 @@ def _read_hamiltonian(self, geom, dtype=np.float64, **kwargs): # Number of orbitals no = int(self.readline()) if no != geom.no: - raise ValueError(self.__class__.__name__ + '.read_hamiltonian has found inconsistent number ' - 'of orbitals in _hr.dat vs the geometry. Remember to re-run Wannier90?') + raise ValueError( + self.__class__.__name__ + + ".read_hamiltonian has found inconsistent number " + "of orbitals in _hr.dat vs the geometry. Remember to re-run Wannier90?" + ) # Number of Wigner-Seitz degeneracy points nrpts = int(self.readline()) @@ -285,7 +297,7 @@ def _read_hamiltonian(self, geom, dtype=np.float64, **kwargs): ws.extend(list(map(int, self.readline().split()))) # Convert to numpy array and invert (for weights) - ws = 1. / np.array(ws, np.float64).flatten() + ws = 1.0 / np.array(ws, np.float64).flatten() # Figure out the number of supercells # and maintain the Hamiltonian in the ham list @@ -297,7 +309,7 @@ def _read_hamiltonian(self, geom, dtype=np.float64, **kwargs): while True: l = self.readline() - if l == '': + if l == "": break # Split here... @@ -323,7 +335,7 @@ def _read_hamiltonian(self, geom, dtype=np.float64, **kwargs): # column # Hr # Hi - ham.append(([iA, iB, iC], r-1, c-1, float(l[5]) * f, float(l[6]) * f)) + ham.append(([iA, iB, iC], r - 1, c - 1, float(l[5]) * f, float(l[6]) * f)) # Update number of super-cells geom.set_nsc([i * 2 + 1 for i in nsc]) @@ -335,7 +347,6 @@ def _read_hamiltonian(self, geom, dtype=np.float64, **kwargs): # populate the Hamiltonian by examining the cutoff value for isc, r, c, hr, hi in ham: - # Calculate the column corresponding to the # correct super-cell c = c + geom.sc_index(isc) * geom.no @@ -346,15 +357,15 @@ def _read_hamiltonian(self, geom, dtype=np.float64, **kwargs): Hi[r, c] = hi del ham - if np.dtype(dtype).kind == 'c': + if np.dtype(dtype).kind == "c": Hr = Hr.tocsr() Hi = Hi.tocsr() - Hr = Hr + 1j*Hi + Hr = Hr + 1j * Hi return Hamiltonian.fromsp(geom, Hr) def read_hamiltonian(self, *args, **kwargs): - """ Read the electronic structure of the Wannier90 output + """Read the electronic structure of the Wannier90 output Parameters ---------- @@ -366,17 +377,17 @@ def read_hamiltonian(self, *args, **kwargs): geom = self.read_geometry() # Set file - self._set_file('_hr.dat') + self._set_file("_hr.dat") H = self._read_hamiltonian(geom, *args, **kwargs) self._set_file() return H def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('win', winSileWannier90, gzip=True) +add_sile("win", winSileWannier90, gzip=True) diff --git a/src/sisl/io/wannier90/tests/test_seedname.py b/src/sisl/io/wannier90/tests/test_seedname.py index c5c988a99c..cf83777519 100644 --- a/src/sisl/io/wannier90/tests/test_seedname.py +++ b/src/sisl/io/wannier90/tests/test_seedname.py @@ -10,14 +10,15 @@ from sisl.io.wannier90 import * pytestmark = [pytest.mark.io, pytest.mark.wannier90, pytest.mark.w90] -_dir = osp.join('sisl', 'io', 'wannier90') +_dir = osp.join("sisl", "io", "wannier90") -@pytest.mark.parametrize("unit", ['', 'Ang\n', 'ang\n', 'bohr\n']) +@pytest.mark.parametrize("unit", ["", "Ang\n", "ang\n", "bohr\n"]) def test_seedname_read_frac(sisl_tmp, unit): - f = sisl_tmp('read_frac.win', _dir) - with open(f, 'w') as fh: - fh.write(""" + f = sisl_tmp("read_frac.win", _dir) + with open(f, "w") as fh: + fh.write( + """ begin unit_cell_cart {} 2. 0. 0. 0. 2. 0 @@ -27,23 +28,27 @@ def test_seedname_read_frac(sisl_tmp, unit): begin atoms_frac C 0.5 0.5 0.5 end -""".format(unit)) +""".format( + unit + ) + ) g = winSileWannier90(f).read_geometry() if len(unit) == 0: - unit = 'ang' - unit = units(unit.strip().capitalize(), 'Ang') + unit = "ang" + unit = units(unit.strip().capitalize(), "Ang") assert np.allclose(g.cell, np.identity(3) * 2 * unit) assert np.allclose(g.xyz, [1 * unit] * 3) -@pytest.mark.parametrize("unit_sc", ['', 'Ang\n', 'ang\n', 'bohr\n']) -@pytest.mark.parametrize("unit", ['', 'Ang\n', 'ang\n', 'bohr\n']) +@pytest.mark.parametrize("unit_sc", ["", "Ang\n", "ang\n", "bohr\n"]) +@pytest.mark.parametrize("unit", ["", "Ang\n", "ang\n", "bohr\n"]) def test_seedname_read_coord(sisl_tmp, unit_sc, unit): - f = sisl_tmp('read_coord.win', _dir) - with open(f, 'w') as fh: - fh.write(""" + f = sisl_tmp("read_coord.win", _dir) + with open(f, "w") as fh: + fh.write( + """ begin unit_cell_cart {} 2. 0. 0. 0. 2. 0 @@ -53,16 +58,19 @@ def test_seedname_read_coord(sisl_tmp, unit_sc, unit): begin atoms_cart {} C 0.5 0.5 0.5 end -""".format(unit_sc, unit)) +""".format( + unit_sc, unit + ) + ) g = winSileWannier90(f).read_geometry() if len(unit) == 0: - unit = 'ang' - unit = units(unit.strip().capitalize(), 'Ang') + unit = "ang" + unit = units(unit.strip().capitalize(), "Ang") if len(unit_sc) == 0: - unit_sc = 'ang' - unit_sc = units(unit_sc.strip().capitalize(), 'Ang') + unit_sc = "ang" + unit_sc = units(unit_sc.strip().capitalize(), "Ang") assert np.allclose(g.cell, np.identity(3) * 2 * unit_sc) assert np.allclose(g.xyz, [0.5 * unit] * 3) @@ -70,8 +78,8 @@ def test_seedname_read_coord(sisl_tmp, unit_sc, unit): @pytest.mark.parametrize("frac", [True, False]) def test_seedname_write_read(sisl_tmp, sisl_system, frac): - f = sisl_tmp('write_read.win', _dir) - sile = winSileWannier90(f, 'w') + f = sisl_tmp("write_read.win", _dir) + sile = winSileWannier90(f, "w") sile.write_geometry(sisl_system.g, frac=frac) g = winSileWannier90(f).read_geometry() diff --git a/src/sisl/io/xsf.py b/src/sisl/io/xsf.py index 6c496548d6..0e8bfe5c30 100644 --- a/src/sisl/io/xsf.py +++ b/src/sisl/io/xsf.py @@ -13,6 +13,7 @@ from sisl.utils import str_spec from ._multiple import SileBinder + # Import sile objects from .sile import * @@ -29,35 +30,41 @@ def _get_kw_index(key): def reset_values(*names_values, animsteps=False): if animsteps: + def reset(self): nonlocal names_values self._write_animsteps() for name, value in names_values: setattr(self, name, value) + else: + def reset(self): nonlocal names_values for name, value in names_values: setattr(self, name, value) + return reset def postprocess(*funcs): - """ Post-processes the returned value according to the funcs for multiple data """ + """Post-processes the returned value according to the funcs for multiple data""" + def post(ret): nonlocal funcs if isinstance(ret[0], tuple): return tuple(func(r) for r, func in zip(zip(*ret), funcs)) return funcs[0](ret) + return post - + # Implementation notice! # The XSF files are compatible with Vesta, but ONLY # if there are no empty lines! @set_module("sisl.io") class xsfSile(Sile): - """ XSF file for XCrySDen + """XSF file for XCrySDen When creating an XSF file one must denote how many geometries to write out. It is also necessary to use the xsf in a context manager, otherwise it will @@ -74,9 +81,9 @@ class xsfSile(Sile): """ def _setup(self, *args, **kwargs): - """ Setup the `xsfSile` after initialization """ + """Setup the `xsfSile` after initialization""" super()._setup(*args, **kwargs) - self._comment = ['#'] + self._comment = ["#"] if "w" in self._mode: self._geometry_max = kwargs.get("steps", 1) else: @@ -104,9 +111,11 @@ def _write_animsteps(self): self._write(f"ANIMSTEPS {self._geometry_max}\n") @sile_fh_open(reset=reset_values(("_geometry_write", 0), animsteps=True)) - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") - def write_lattice(self, lattice, fmt='.8f'): - """ Writes the supercell to the contained file + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) + def write_lattice(self, lattice, fmt=".8f"): + """Writes the supercell to the contained file Parameters ---------- @@ -120,7 +129,10 @@ def write_lattice(self, lattice, fmt='.8f'): # Write out top-header stuff from time import gmtime, strftime - self._write_once('# File created by: sisl {}\n#\n'.format(strftime("%Y-%m-%d", gmtime()))) + + self._write_once( + "# File created by: sisl {}\n#\n".format(strftime("%Y-%m-%d", gmtime())) + ) if all(lattice.nsc == 1): self._write_once("MOLECULE\n#\n") @@ -129,28 +141,28 @@ def write_lattice(self, lattice, fmt='.8f'): elif lattice.nsc[0] > 1: self._write_once("POLYMER\n#\n") else: - self._write_once('CRYSTAL\n#\n') + self._write_once("CRYSTAL\n#\n") - self._write_once('# Primitive lattice vectors:\n#\n') - self._write_key_index('PRIMVEC') + self._write_once("# Primitive lattice vectors:\n#\n") + self._write_key_index("PRIMVEC") # We write the cell coordinates as the cell coordinates - fmt_str = f'{{:{fmt}}} ' * 3 + '\n' + fmt_str = f"{{:{fmt}}} " * 3 + "\n" for i in (0, 1, 2): self._write(fmt_str.format(*lattice.cell[i, :])) # Convert the unit cell to a conventional cell (90-90-90)) # It seems this simply allows to store both formats in # the same file. However the below stuff is not correct. - #self._write_once('#\n# Conventional lattice vectors:\n#\n') - #self._write_key_index('CONVVEC') - #convcell = lattice.to.Cuboid(orthogonal=True)._v - #for i in [0, 1, 2]: + # self._write_once('#\n# Conventional lattice vectors:\n#\n') + # self._write_key_index('CONVVEC') + # convcell = lattice.to.Cuboid(orthogonal=True)._v + # for i in [0, 1, 2]: # self._write(fmt_str.format(*convcell[i, :])) @sile_fh_open(reset=reset_values(("_geometry_write", 0), animsteps=True)) - def write_geometry(self, geometry, fmt='.8f', data=None): - """ Writes the geometry to the contained file + def write_geometry(self, geometry, fmt=".8f", data=None): + """Writes the geometry to the contained file Parameters ---------- @@ -168,10 +180,10 @@ def write_geometry(self, geometry, fmt='.8f', data=None): if has_data: data.shape = (-1, 3) - self._write_once('#\n# Atomic coordinates (in primitive coordinates)\n#\n') + self._write_once("#\n# Atomic coordinates (in primitive coordinates)\n#\n") self._geometry_write += 1 self._write_key_index("PRIMCOORD") - self._write(f'{len(geometry)} 1\n') + self._write(f"{len(geometry)} 1\n") non_valid_Z = (geometry.atoms.Z <= 0).nonzero()[0] if len(non_valid_Z) > 0: @@ -179,13 +191,13 @@ def write_geometry(self, geometry, fmt='.8f', data=None): if has_data: fmt_str = ( - '{{0:3d}} {{1:{0}}} {{2:{0}}} {{3:{0}}} {{4:{0}}} {{5:{0}}} {{6:{0}}}\n' + "{{0:3d}} {{1:{0}}} {{2:{0}}} {{3:{0}}} {{4:{0}}} {{5:{0}}} {{6:{0}}}\n" ).format(fmt) for ia in geometry: tmp = np.append(geometry.xyz[ia, :], data[ia, :]) self._write(fmt_str.format(geometry.atoms[ia].Z, *tmp)) else: - fmt_str = '{{0:3d}} {{1:{0}}} {{2:{0}}} {{3:{0}}}\n'.format(fmt) + fmt_str = "{{0:3d}} {{1:{0}}} {{2:{0}}} {{3:{0}}}\n".format(fmt) for ia in geometry: self._write(fmt_str.format(geometry.atoms[ia].Z, *geometry.xyz[ia, :])) @@ -204,7 +216,7 @@ def _r_geometry_next(self, lattice=None, atoms=None, ret_data=False): data = None line = " " - while line != '': + while line != "": line = self.readline() if line.isspace(): @@ -264,7 +276,9 @@ def next(): break elif line.startswith("CONVCOORD"): - raise NotImplementedError(f"{self.__class__.__name__} does not implement reading CONVCOORD") + raise NotImplementedError( + f"{self.__class__.__name__} does not implement reading CONVCOORD" + ) elif line.startswith("CRYSTAL"): self._read_type = "CRYSTAL" @@ -293,10 +307,12 @@ def next(): cell = lattice elif typ == "MOLECULE": - cell = Lattice(np.diag(xyz.max(0) - xyz.min(0) + 10.)) + cell = Lattice(np.diag(xyz.max(0) - xyz.min(0) + 10.0)) if cell is None: - raise ValueError(f"{self.__class__.__name__} could not find lattice parameters.") + raise ValueError( + f"{self.__class__.__name__} could not find lattice parameters." + ) # overwrite the currently read cell self._read_cell = cell @@ -317,9 +333,11 @@ def next(): return geom @SileBinder(postprocess=postprocess(list, list)) - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def read_geometry(self, lattice=None, atoms=None, ret_data=False): - """ Geometry contained in file, and optionally the associated data + """Geometry contained in file, and optionally the associated data If the file contains more geometries, one can read multiple geometries by using the arguments `start`, `stop` and `step`. @@ -339,7 +357,7 @@ def read_geometry(self, lattice=None, atoms=None, ret_data=False): @sile_fh_open() def write_grid(self, *args, **kwargs): - """ Store grid(s) data to an XSF file + """Store grid(s) data to an XSF file Examples -------- @@ -365,48 +383,49 @@ def write_grid(self, *args, **kwargs): # for now we do not allow an animation with grid data... should this # even work? if self._geometry_max > 1: - raise NotImplementedError(f"{self.__class__.__name__}.write_grid not allowed in an animation file.") + raise NotImplementedError( + f"{self.__class__.__name__}.write_grid not allowed in an animation file." + ) - geom = kwargs.get('geometry', args[0].geometry) + geom = kwargs.get("geometry", args[0].geometry) if geom is None: geom = Geometry([0, 0, 0], AtomUnknown(999), lattice=args[0].lattice) self.write_geometry(geom) # Buffer size for writing - buffersize = min(kwargs.get('buffersize', 6144), args[0].grid.size) + buffersize = min(kwargs.get("buffersize", 6144), args[0].grid.size) # Format for precision - fmt = kwargs.get('fmt', '.5e') + fmt = kwargs.get("fmt", ".5e") - self._write('BEGIN_BLOCK_DATAGRID_3D\n') - name = kwargs.get('name', 'sisl_{}'.format(len(args))) + self._write("BEGIN_BLOCK_DATAGRID_3D\n") + name = kwargs.get("name", "sisl_{}".format(len(args))) # Transfer all spaces to underscores (no spaces allowed) - self._write(' ' + name.replace(' ', '_') + '\n') - _v3 = (('{:' + fmt + '} ') * 3).strip() + '\n' + self._write(" " + name.replace(" ", "_") + "\n") + _v3 = (("{:" + fmt + "} ") * 3).strip() + "\n" def write_cell(grid): # Now write the grid - self._write(' {} {} {}\n'.format(*grid.shape)) - self._write(' ' + _v3.format(*grid.origin)) - self._write(' ' + _v3.format(*grid.cell[0, :])) - self._write(' ' + _v3.format(*grid.cell[1, :])) - self._write(' ' + _v3.format(*grid.cell[2, :])) + self._write(" {} {} {}\n".format(*grid.shape)) + self._write(" " + _v3.format(*grid.origin)) + self._write(" " + _v3.format(*grid.cell[0, :])) + self._write(" " + _v3.format(*grid.cell[1, :])) + self._write(" " + _v3.format(*grid.cell[2, :])) for i, grid in enumerate(args): - if isinstance(grid, Grid): - name = kwargs.get(f'grid{i}', str(i)) + name = kwargs.get(f"grid{i}", str(i)) else: # it must be a tuple name, grid = grid - name = kwargs.get(f'grid{i}', name) + name = kwargs.get(f"grid{i}", name) is_complex = np.iscomplexobj(grid.grid) if is_complex: - self._write(f' BEGIN_DATAGRID_3D_real_{name}\n') + self._write(f" BEGIN_DATAGRID_3D_real_{name}\n") else: - self._write(f' BEGIN_DATAGRID_3D_{name}\n') + self._write(f" BEGIN_DATAGRID_3D_{name}\n") write_cell(grid) @@ -414,34 +433,44 @@ def write_cell(grid): # for y # for x # write... - _fmt = '{:' + fmt + '}\n' - for x in np.nditer(np.asarray(grid.grid.real.T, order='C').reshape(-1), flags=['external_loop', 'buffered'], - op_flags=[['readonly']], order='C', buffersize=buffersize): + _fmt = "{:" + fmt + "}\n" + for x in np.nditer( + np.asarray(grid.grid.real.T, order="C").reshape(-1), + flags=["external_loop", "buffered"], + op_flags=[["readonly"]], + order="C", + buffersize=buffersize, + ): self._write((_fmt * x.shape[0]).format(*x.tolist())) - self._write(' END_DATAGRID_3D\n') + self._write(" END_DATAGRID_3D\n") # Skip if not complex if not is_complex: continue - self._write(f' BEGIN_DATAGRID_3D_imag_{name}\n') + self._write(f" BEGIN_DATAGRID_3D_imag_{name}\n") write_cell(grid) - for x in np.nditer(np.asarray(grid.grid.imag.T, order='C').reshape(-1), flags=['external_loop', 'buffered'], - op_flags=[['readonly']], order='C', buffersize=buffersize): + for x in np.nditer( + np.asarray(grid.grid.imag.T, order="C").reshape(-1), + flags=["external_loop", "buffered"], + op_flags=[["readonly"]], + order="C", + buffersize=buffersize, + ): self._write((_fmt * x.shape[0]).format(*x.tolist())) - self._write(' END_DATAGRID_3D\n') + self._write(" END_DATAGRID_3D\n") - self._write('END_BLOCK_DATAGRID_3D\n') + self._write("END_BLOCK_DATAGRID_3D\n") def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) def ArgumentParser_out(self, p, *args, **kwargs): - """ Adds arguments only if this file is an output file + """Adds arguments only if this file is an output file Parameters ---------- @@ -449,29 +478,42 @@ def ArgumentParser_out(self, p, *args, **kwargs): the parser which gets amended the additional output options. """ import argparse + ns = kwargs.get("namespace", None) if ns is None: - class _(): + + class _: pass + ns = _() # We will add the vector data class VectorNoScale(argparse.Action): def __call__(self, parser, ns, no_value, option_string=None): setattr(ns, "_vector_scale", False) - p.add_argument("--no-vector-scale", "-nsv", nargs=0, - action=VectorNoScale, - help="""Do not modify vector components (same as --vector-scale 1.)""") + + p.add_argument( + "--no-vector-scale", + "-nsv", + nargs=0, + action=VectorNoScale, + help="""Do not modify vector components (same as --vector-scale 1.)""", + ) # Default to scale the vectors setattr(ns, "_vector_scale", True) # We will add the vector data class VectorScale(argparse.Action): def __call__(self, parser, ns, value, option_string=None): - setattr(ns, '_vector_scale', float(value)) - p.add_argument('--vector-scale', '-sv', metavar='SCALE', - action=VectorScale, - help="""Scale vector components by this factor.""") + setattr(ns, "_vector_scale", float(value)) + + p.add_argument( + "--vector-scale", + "-sv", + metavar="SCALE", + action=VectorScale, + help="""Scale vector components by this factor.""", + ) # We will add the vector data class Vectors(argparse.Action): @@ -479,7 +521,7 @@ def __call__(self, parser, ns, values, option_string=None): routine = values.pop(0) # Default input file - input_file = getattr(ns, '_input_file', None) + input_file = getattr(ns, "_input_file", None) # Figure out which of the segments are a file for i, val in enumerate(values): @@ -493,11 +535,12 @@ def __call__(self, parser, ns, values, option_string=None): # Try and read the vector from sisl.io import get_sile - input_sile = get_sile(input_file, mode='r') + + input_sile = get_sile(input_file, mode="r") vector = None - if hasattr(input_sile, f'read_{routine}'): - vector = getattr(input_sile, f'read_{routine}')(*values) + if hasattr(input_sile, f"read_{routine}"): + vector = getattr(input_sile, f"read_{routine}")(*values) if vector is None: # Try the read_data function @@ -516,22 +559,34 @@ def __call__(self, parser, ns, values, option_string=None): if vector is None: # Use title to capitalize - raise ValueError('{} could not be read from file: {}.'.format(routine.title(), input_file)) + raise ValueError( + "{} could not be read from file: {}.".format( + routine.title(), input_file + ) + ) if len(vector) != len(ns._geometry): - raise ValueError(f'read_{routine} could read from file: {input_file}, sizes does not conform to geometry.') - setattr(ns, '_vector', vector) - p.add_argument('--vector', '-v', metavar=('DATA', '*ARGS[, FILE]'), nargs='+', - action=Vectors, - help="""Adds vector arrows for each atom, first argument is type (force, moment, ...). + raise ValueError( + f"read_{routine} could read from file: {input_file}, sizes does not conform to geometry." + ) + setattr(ns, "_vector", vector) + + p.add_argument( + "--vector", + "-v", + metavar=("DATA", "*ARGS[, FILE]"), + nargs="+", + action=Vectors, + help="""Adds vector arrows for each atom, first argument is type (force, moment, ...). If the current input file contains the vectors no second argument is necessary, else the file containing the data is required as the last input. Any arguments inbetween are passed to the `read_data` function (in order). By default the vectors scaled by 1 / max(|V|) such that the longest vector has length 1. - """) + """, + ) -add_sile('xsf', xsfSile, case=False, gzip=True) -add_sile('axsf', xsfSile, case=False, gzip=True) +add_sile("xsf", xsfSile, case=False, gzip=True) +add_sile("axsf", xsfSile, case=False, gzip=True) diff --git a/src/sisl/io/xyz.py b/src/sisl/io/xyz.py index 0c276399bf..7ba3b6bc8b 100644 --- a/src/sisl/io/xyz.py +++ b/src/sisl/io/xyz.py @@ -21,10 +21,10 @@ @set_module("sisl.io") class xyzSile(Sile): - """ XYZ file object """ + """XYZ file object""" def _parse_lattice(self, header, xyz, lattice): - """Internal helper routine for extracting the lattice """ + """Internal helper routine for extracting the lattice""" if lattice is not None: return lattice @@ -59,13 +59,11 @@ def _parse_lattice(self, header, xyz, lattice): if "Origin" in header: origin = _a.fromiterd(header.pop("Origin").strip('"').split()).reshape(3) - return Lattice(cell, nsc=nsc, - origin=origin, - boundary_condition=bc) + return Lattice(cell, nsc=nsc, origin=origin, boundary_condition=bc) @sile_fh_open() - def write_geometry(self, geometry, fmt='.8f', comment=None): - """ Writes the geometry to the contained file + def write_geometry(self, geometry, fmt=".8f", comment=None): + """Writes the geometry to the contained file Parameters ---------- @@ -82,16 +80,18 @@ def write_geometry(self, geometry, fmt='.8f', comment=None): lattice = geometry.lattice # Write the number of atoms in the geometry - self._write(' {}\n'.format(len(geometry))) + self._write(" {}\n".format(len(geometry))) # Write out the cell information in the comment field # This contains the cell vectors in a single vector (3 + 3 + 3) # quantities, plus the number of supercells (3 ints) fields = [] - fields.append(('Lattice="' + f'{{:{fmt}}} ' * 9 + '"').format(*geometry.cell.ravel())) + fields.append( + ('Lattice="' + f"{{:{fmt}}} " * 9 + '"').format(*geometry.cell.ravel()) + ) nsc = geometry.nsc[:] fields.append('nsc="{} {} {}"'.format(*nsc)) - pbc = ['T' if n else 'F' for n in lattice.pbc] + pbc = ["T" if n else "F" for n in lattice.pbc] fields.append('pbc="{} {} {}"'.format(*pbc)) BC = BoundaryCondition.getitem bc = [f"{BC(n[0]).name} {BC(n[1]).name}" for n in lattice.boundary_condition] @@ -99,31 +99,33 @@ def write_geometry(self, geometry, fmt='.8f', comment=None): if comment is not None: fields.append(f'Comment="{comment}"') - self._write(' '.join(fields) + "\n") + self._write(" ".join(fields) + "\n") - fmt_str = '{{0:2s}} {{1:{0}}} {{2:{0}}} {{3:{0}}}\n'.format(fmt) + fmt_str = "{{0:2s}} {{1:{0}}} {{2:{0}}} {{3:{0}}}\n".format(fmt) for ia, a, _ in geometry.iter_species(): s = a.symbol - s = {'fa': 'Ds'}.get(s, s) + s = {"fa": "Ds"}.get(s, s) self._write(fmt_str.format(s, *geometry.xyz[ia, :])) def _r_geometry_skip(self, *args, **kwargs): - """ Read the geometry for a generic xyz file (not sisl, nor ASE) """ + """Read the geometry for a generic xyz file (not sisl, nor ASE)""" line = self.readline() - if line == '': + if line == "": return None na = int(line) line = self.readline - for _ in range(na+1): + for _ in range(na + 1): line() return na @SileBinder(skip_func=_r_geometry_skip) @sile_fh_open() - @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", from_version="0.15") + @deprecate_argument( + "sc", "lattice", "use lattice= instead of sc=", from_version="0.15" + ) def read_geometry(self, atoms=None, lattice=None): - """ Returns Geometry object from the XYZ file + """Returns Geometry object from the XYZ file Parameters ---------- @@ -133,7 +135,7 @@ def read_geometry(self, atoms=None, lattice=None): the lattice to be associated with the geometry """ line = self.readline() - if line == '': + if line == "": return None # Read number of atoms @@ -141,9 +143,7 @@ def read_geometry(self, atoms=None, lattice=None): # Read header, and try and convert to dictionary header = self.readline() - header = {k: v.strip('"') for k, v in - header_to_dict(header).items() - } + header = {k: v.strip('"') for k, v in header_to_dict(header).items()} # Read atoms and coordinates sp = [None] * na @@ -161,10 +161,10 @@ def read_geometry(self, atoms=None, lattice=None): return Geometry(xyz, atoms=sp, lattice=lattice) def ArgumentParser(self, p=None, *args, **kwargs): - """ Returns the arguments that is available for this Sile """ + """Returns the arguments that is available for this Sile""" newkw = Geometry._ArgumentParser_args_single() newkw.update(kwargs) return self.read_geometry().ArgumentParser(p, *args, **newkw) -add_sile('xyz', xyzSile, case=False, gzip=True) +add_sile("xyz", xyzSile, case=False, gzip=True) diff --git a/src/sisl/lattice.py b/src/sisl/lattice.py index c5dbe00419..4ee7a30af6 100644 --- a/src/sisl/lattice.py +++ b/src/sisl/lattice.py @@ -46,21 +46,24 @@ class BoundaryCondition(IntEnum): @classmethod def getitem(cls, key): - """Search for a specific integer entry by value, and not by name """ + """Search for a specific integer entry by value, and not by name""" if isinstance(key, cls): return key if isinstance(key, bool): if key: return cls.PERIODIC - raise ValueError(f"{cls.__name__}.getitem does not allow False, which BC should this refer to?") + raise ValueError( + f"{cls.__name__}.getitem does not allow False, which BC should this refer to?" + ) if isinstance(key, str): key = key.upper() if len(key) == 1: - key = {"U": "UNKNOWN", - "P": "PERIODIC", - "D": "DIRICHLET", - "N": "NEUMANN", - "O": "OPEN", + key = { + "U": "UNKNOWN", + "P": "PERIODIC", + "D": "DIRICHLET", + "N": "NEUMANN", + "O": "OPEN", }[key] for bc in cls: if bc.name.startswith(key): @@ -71,22 +74,21 @@ def getitem(cls, key): return bc raise KeyError(f"{cls.__name__}.getitem could not find key={key}") + BoundaryConditionType = Union[BoundaryCondition, int, str, bool] -SeqBoundaryConditionType = Union[BoundaryConditionType, - Sequence[BoundaryConditionType]] +SeqBoundaryConditionType = Union[BoundaryConditionType, Sequence[BoundaryConditionType]] @set_module("sisl") -class Lattice(_Dispatchs, - dispatchs=[ - ("new", ClassDispatcher("new", - instance_dispatcher=TypeDispatcher)), - ("to", ClassDispatcher("to", - type_dispatcher=None)) - ], - when_subclassing="copy", - ): - r""" A cell class to retain lattice vectors and a supercell structure +class Lattice( + _Dispatchs, + dispatchs=[ + ("new", ClassDispatcher("new", instance_dispatcher=TypeDispatcher)), + ("to", ClassDispatcher("to", type_dispatcher=None)), + ], + when_subclassing="copy", +): + r"""A cell class to retain lattice vectors and a supercell structure The supercell structure is comprising the *primary* unit-cell and neighbouring unit-cells. The number of supercells is given by the attribute `nsc` which @@ -110,14 +112,18 @@ class Lattice(_Dispatchs, """ # We limit the scope of this Lattice object. - __slots__ = ('cell', '_origin', 'nsc', 'n_s', '_sc_off', '_isc_off', '_bc') + __slots__ = ("cell", "_origin", "nsc", "n_s", "_sc_off", "_isc_off", "_bc") #: Internal reference to `BoundaryCondition` for simpler short-hands BC = BoundaryCondition - def __init__(self, cell, nsc=None, origin=None, - boundary_condition: SeqBoundaryConditionType =BoundaryCondition.PERIODIC): - + def __init__( + self, + cell, + nsc=None, + origin=None, + boundary_condition: SeqBoundaryConditionType = BoundaryCondition.PERIODIC, + ): if nsc is None: nsc = [1, 1, 1] @@ -139,16 +145,16 @@ def __init__(self, cell, nsc=None, origin=None, @property def length(self) -> ndarray: - """ Length of each lattice vector """ + """Length of each lattice vector""" return fnorm(self.cell) @property def volume(self): - """ Volume of cell """ + """Volume of cell""" return abs(dot3(self.cell[0, :], cross3(self.cell[1, :], self.cell[2, :]))) def area(self, ax0, ax1): - """ Calculate the area spanned by the two axis `ax0` and `ax1` """ + """Calculate the area spanned by the two axis `ax0` and `ax1`""" return (cross3(self.cell[ax0, :], self.cell[ax1, :]) ** 2).sum() ** 0.5 @property @@ -170,17 +176,20 @@ def pbc(self) -> np.ndarray: @property def origin(self) -> ndarray: - """ Origin for the cell """ + """Origin for the cell""" return self._origin @origin.setter def origin(self, origin): - """ Set origin for the cell """ + """Set origin for the cell""" self._origin[:] = origin - @deprecation("toCuboid is deprecated, please use lattice.to['cuboid'](...) instead.", "0.15.0") + @deprecation( + "toCuboid is deprecated, please use lattice.to['cuboid'](...) instead.", + "0.15.0", + ) def toCuboid(self, *args, **kwargs): - """ A cuboid with vectors as this unit-cell and center with respect to its origin + """A cuboid with vectors as this unit-cell and center with respect to its origin Parameters ---------- @@ -189,12 +198,14 @@ def toCuboid(self, *args, **kwargs): """ return self.to[Cuboid](*args, **kwargs) - def set_boundary_condition(self, - boundary: Optional[SeqBoundaryConditionType] =None, - a: Optional[SeqBoundaryConditionType] =None, - b: Optional[SeqBoundaryConditionType] =None, - c: Optional[SeqBoundaryConditionType] =None): - """ Set the boundary conditions on the grid + def set_boundary_condition( + self, + boundary: Optional[SeqBoundaryConditionType] = None, + a: Optional[SeqBoundaryConditionType] = None, + b: Optional[SeqBoundaryConditionType] = None, + c: Optional[SeqBoundaryConditionType] = None, + ): + """Set the boundary conditions on the grid Parameters ---------- @@ -213,6 +224,7 @@ def set_boundary_condition(self, if specifying periodic one one boundary, so must the opposite side. """ getitem = BoundaryCondition.getitem + def conv(v): if v is None: return v @@ -247,18 +259,25 @@ def conv(v): self._bc[d, :] = v # shorthand for bc - for nsc, bc, changed in zip(self.nsc, self._bc == BoundaryCondition.PERIODIC, self._bc != old): + for nsc, bc, changed in zip( + self.nsc, self._bc == BoundaryCondition.PERIODIC, self._bc != old + ): if bc.any() and not bc.all(): - raise ValueError(f"{self.__class__.__name__}.set_boundary_condition has a one non-periodic and " - "one periodic direction. If one direction is periodic, both instances " - "must have that BC.") + raise ValueError( + f"{self.__class__.__name__}.set_boundary_condition has a one non-periodic and " + "one periodic direction. If one direction is periodic, both instances " + "must have that BC." + ) if changed.any() and (~bc).all() and nsc > 1: - info(f"{self.__class__.__name__}.set_boundary_condition is having image connections (nsc={nsc}>1) " - "while having a non-periodic boundary condition.") - + info( + f"{self.__class__.__name__}.set_boundary_condition is having image connections (nsc={nsc}>1) " + "while having a non-periodic boundary condition." + ) - def parameters(self, rad: bool=False) -> Tuple[float, float, float, float, float, float]: - r""" Cell parameters of this cell in 3 lengths and 3 angles + def parameters( + self, rad: bool = False + ) -> Tuple[float, float, float, float, float, float]: + r"""Cell parameters of this cell in 3 lengths and 3 angles Notes ----- @@ -287,7 +306,7 @@ def parameters(self, rad: bool=False) -> Tuple[float, float, float, float, float angle between a and b vectors """ if rad: - f = 1. + f = 1.0 else: f = 180 / np.pi @@ -296,6 +315,7 @@ def parameters(self, rad: bool=False) -> Tuple[float, float, float, float, float abc = fnorm(cell) from math import acos + cell = cell / abc.reshape(-1, 1) alpha = acos(dot3(cell[1, :], cell[2, :])) * f beta = acos(dot3(cell[0, :], cell[2, :])) * f @@ -304,7 +324,7 @@ def parameters(self, rad: bool=False) -> Tuple[float, float, float, float, float return abc[0], abc[1], abc[2], alpha, beta, gamma def _fill(self, non_filled, dtype=None): - """ Return a zero filled array of length 3 """ + """Return a zero filled array of length 3""" if len(non_filled) == 3: return non_filled @@ -335,11 +355,11 @@ def _fill(self, non_filled, dtype=None): return f def _fill_sc(self, supercell_index): - """ Return a filled supercell index by filling in zeros where needed """ + """Return a filled supercell index by filling in zeros where needed""" return self._fill(supercell_index, dtype=np.int32) def set_nsc(self, nsc=None, a=None, b=None, c=None): - """ Sets the number of supercells in the 3 different cell directions + """Sets the number of supercells in the 3 different cell directions Parameters ---------- @@ -368,8 +388,9 @@ def set_nsc(self, nsc=None, a=None, b=None, c=None): self.nsc[i] = 1 if np.sum(self.nsc % 2) != 3: raise ValueError( - "Supercells has to be of un-even size. The primary cell counts " + - "one, all others count 2") + "Supercells has to be of un-even size. The primary cell counts " + + "one, all others count 2" + ) # We might use this very often, hence we store it self.n_s = _a.prodi(self.nsc) @@ -381,7 +402,8 @@ def set_nsc(self, nsc=None, a=None, b=None, c=None): def ret_range(val): i = val // 2 - return range(-i, i+1) + return range(-i, i + 1) + x = ret_range(n[0]) y = ret_range(n[1]) z = ret_range(n[2]) @@ -402,33 +424,33 @@ def ret_range(val): self._update_isc_off() def _update_isc_off(self): - """ Internal routine for updating the supercell indices """ + """Internal routine for updating the supercell indices""" for i in range(self.n_s): d = self.sc_off[i, :] self._isc_off[d[0], d[1], d[2]] = i @property def sc_off(self) -> ndarray: - """ Integer supercell offsets """ + """Integer supercell offsets""" return self._sc_off @sc_off.setter def sc_off(self, sc_off): - """ Set the supercell offset """ - self._sc_off[:, :] = _a.arrayi(sc_off, order='C') + """Set the supercell offset""" + self._sc_off[:, :] = _a.arrayi(sc_off, order="C") self._update_isc_off() @property def isc_off(self) -> ndarray: - """ Internal indexed supercell ``[ia, ib, ic] == i`` """ + """Internal indexed supercell ``[ia, ib, ic] == i``""" return self._isc_off def __iter__(self): - """ Iterate the supercells and the indices of the supercells """ + """Iterate the supercells and the indices of the supercells""" yield from enumerate(self.sc_off) def copy(self, cell=None, origin=None): - """ A deepcopy of the object + """A deepcopy of the object Parameters ---------- @@ -456,7 +478,7 @@ def copy(self, cell=None, origin=None): return copy def fit(self, xyz, axis=None, tol=0.05): - """ Fit the supercell to `xyz` such that the unit-cell becomes periodic in the specified directions + """Fit the supercell to `xyz` such that the unit-cell becomes periodic in the specified directions The fitted supercell tries to determine the unit-cell parameters by solving a set of linear equations corresponding to the current supercell vectors. @@ -499,8 +521,10 @@ def fit(self, xyz, axis=None, tol=0.05): dist = np.sqrt((dot(cell.T, (x - ix).T) ** 2).sum(0)) idx = (dist <= tol).nonzero()[0] if len(idx) == 0: - raise ValueError('Could not fit the cell parameters to the coordinates ' - 'due to insufficient accuracy (try increase the tolerance)') + raise ValueError( + "Could not fit the cell parameters to the coordinates " + "due to insufficient accuracy (try increase the tolerance)" + ) # Reduce problem to allowed values below the tolerance ix = ix[idx, :] @@ -525,10 +549,10 @@ def fit(self, xyz, axis=None, tol=0.05): return self.copy(cell) - def swapaxes(self, axes_a: Union[int, str], - axes_b: Union[int, str], - what: str="abc") -> Lattice: - r""" Swaps axes `axes_a` and `axes_b` + def swapaxes( + self, axes_a: Union[int, str], axes_b: Union[int, str], what: str = "abc" + ) -> Lattice: + r"""Swaps axes `axes_a` and `axes_b` Swapaxes is a versatile method for changing the order of axes elements, either lattice vector order, or Cartesian @@ -586,10 +610,14 @@ def swapaxes(self, axes_a: Union[int, str], axes_a = "xyz"[axes_a] axes_b = "xyz"[axes_b] else: - raise ValueError(f"{self.__class__.__name__}.swapaxes could not understand 'what' " - "must contain abc and/or xyz.") + raise ValueError( + f"{self.__class__.__name__}.swapaxes could not understand 'what' " + "must contain abc and/or xyz." + ) elif (not isinstance(axes_a, str)) or (not isinstance(axes_b, str)): - raise ValueError(f"{self.__class__.__name__}.swapaxes axes arguments must be either all int or all str, not a mix.") + raise ValueError( + f"{self.__class__.__name__}.swapaxes axes arguments must be either all int or all str, not a mix." + ) cell = self.cell nsc = self.nsc @@ -597,7 +625,9 @@ def swapaxes(self, axes_a: Union[int, str], bc = self.boundary_condition if len(axes_a) != len(axes_b): - raise ValueError(f"{self.__class__.__name__}.swapaxes expects axes_a and axes_b to have the same lengeth {len(axes_a)}, {len(axes_b)}.") + raise ValueError( + f"{self.__class__.__name__}.swapaxes expects axes_a and axes_b to have the same lengeth {len(axes_a)}, {len(axes_b)}." + ) for a, b in zip(axes_a, axes_b): idx = [0, 1, 2] @@ -606,7 +636,9 @@ def swapaxes(self, axes_a: Union[int, str], bidx = "abcxyz".index(b) if aidx // 3 != bidx // 3: - raise ValueError(f"{self.__class__.__name__}.swapaxes expects axes_a and axes_b to belong to the same category, do not mix lattice vector swaps with Cartesian coordinates.") + raise ValueError( + f"{self.__class__.__name__}.swapaxes expects axes_a and axes_b to belong to the same category, do not mix lattice vector swaps with Cartesian coordinates." + ) if aidx < 3: idx[aidx], idx[bidx] = idx[bidx], idx[aidx] @@ -625,13 +657,12 @@ def swapaxes(self, axes_a: Union[int, str], origin = origin[idx] bc = bc[idx] - return self.__class__(cell.copy(), - nsc=nsc.copy(), - origin=origin.copy(), - boundary_condition=bc) + return self.__class__( + cell.copy(), nsc=nsc.copy(), origin=origin.copy(), boundary_condition=bc + ) def plane(self, ax1, ax2, origin=True): - """ Query point and plane-normal for the plane spanning `ax1` and `ax2` + """Query point and plane-normal for the plane spanning `ax1` and `ax2` Parameters ---------- @@ -694,7 +725,7 @@ def plane(self, ax1, ax2, origin=True): # If d is positive then the normal vector is pointing towards # the center, so rotate 180 - if dot3(n, up / 2) > 0.: + if dot3(n, up / 2) > 0.0: n *= -1 if origin: @@ -703,7 +734,7 @@ def plane(self, ax1, ax2, origin=True): return -n, up def __mul__(self, m): - """ Implement easy repeat function + """Implement easy repeat function Parameters ---------- @@ -727,7 +758,7 @@ def __mul__(self, m): @property def icell(self): - """ Returns the reciprocal (inverse) cell for the `Lattice`. + """Returns the reciprocal (inverse) cell for the `Lattice`. Note: The returned vectors are still in ``[0, :]`` format and not as returned by an inverse LAPACK algorithm. @@ -736,7 +767,7 @@ def icell(self): @property def rcell(self): - """ Returns the reciprocal cell for the `Lattice` with ``2*np.pi`` + """Returns the reciprocal cell for the `Lattice` with ``2*np.pi`` Note: The returned vectors are still in [0, :] format and not as returned by an inverse LAPACK algorithm. @@ -744,7 +775,7 @@ def rcell(self): return cell_reciprocal(self.cell) def cell2length(self, length, axes=(0, 1, 2)) -> ndarray: - """ Calculate cell vectors such that they each have length `length` + """Calculate cell vectors such that they each have length `length` Parameters ---------- @@ -770,17 +801,20 @@ def cell2length(self, length, axes=(0, 1, 2)) -> ndarray: if len(length) == 1: length = np.tile(length, len(axes)) else: - raise ValueError(f"{self.__class__.__name__}.cell2length length parameter should be a single " - "float, or an array of values according to axes argument.") + raise ValueError( + f"{self.__class__.__name__}.cell2length length parameter should be a single " + "float, or an array of values according to axes argument." + ) return self.cell[axes] * (length / self.length[axes]).reshape(-1, 1) - @deprecate_argument("only", "what", - "argument only has been deprecated in favor of what, please update your code.", - "0.14.0") - def rotate(self, angle, v, - rad: bool=False, - what: str="abc") -> Lattice: - """ Rotates the supercell, in-place by the angle around the vector + @deprecate_argument( + "only", + "what", + "argument only has been deprecated in favor of what, please update your code.", + "0.14.0", + ) + def rotate(self, angle, v, rad: bool = False, what: str = "abc") -> Lattice: + """Rotates the supercell, in-place by the angle around the vector One can control which cell vectors are rotated by designating them individually with ``only='[abc]'``. @@ -800,7 +834,11 @@ def rotate(self, angle, v, if isinstance(v, Integral): v = direction(v, abc=self.cell, xyz=np.diag([1, 1, 1])) elif isinstance(v, str): - v = reduce(lambda a, b: a + direction(b, abc=self.cell, xyz=np.diag([1, 1, 1])), v, 0) + v = reduce( + lambda a, b: a + direction(b, abc=self.cell, xyz=np.diag([1, 1, 1])), + v, + 0, + ) # flatten => copy vn = _a.asarrayd(v).flatten() vn /= fnorm(vn) @@ -808,7 +846,7 @@ def rotate(self, angle, v, q /= q.norm() # normalize the quaternion cell = np.copy(self.cell) idx = [] - for i, d in enumerate('abc'): + for i, d in enumerate("abc"): if d in what: idx.append(i) if idx: @@ -816,13 +854,13 @@ def rotate(self, angle, v, return self.copy(cell) def offset(self, isc=None): - """ Returns the supercell offset of the supercell index """ + """Returns the supercell offset of the supercell index""" if isc is None: return _a.arrayd([0, 0, 0]) return dot(isc, self.cell) def add(self, other): - """ Add two supercell lattice vectors to each other + """Add two supercell lattice vectors to each other Parameters ---------- @@ -842,7 +880,7 @@ def __add__(self, other): __radd__ = __add__ def add_vacuum(self, vacuum, axis, orthogonal_to_plane=False): - """ Add vacuum along the `axis` lattice vector + """Add vacuum along the `axis` lattice vector Parameters ---------- @@ -859,7 +897,7 @@ def add_vacuum(self, vacuum, axis, orthogonal_to_plane=False): d /= fnorm(d) if orthogonal_to_plane: # first calculate the normal vector of the other plane - n = cross3(cell[axis-1], cell[axis-2]) + n = cross3(cell[axis - 1], cell[axis - 2]) n /= fnorm(n) # now project onto cell projection = n @ d @@ -874,7 +912,7 @@ def add_vacuum(self, vacuum, axis, orthogonal_to_plane=False): return self.copy(cell) def sc_index(self, sc_off): - """ Returns the integer index in the sc_off list that corresponds to `sc_off` + """Returns the integer index in the sc_off list that corresponds to `sc_off` Returns the index for the supercell in the global offset. @@ -884,9 +922,11 @@ def sc_index(self, sc_off): super cell specification. For each axis having value ``None`` all supercells along that axis is returned. """ + def _assert(m, v): if np.any(np.abs(v) > m): raise ValueError("Requesting a non-existing supercell index") + hsc = self.nsc // 2 if len(sc_off) == 0: @@ -944,7 +984,7 @@ def vertices(self): return verts @ self.cell def scale(self, scale, what="abc"): - """ Scale lattice vectors + """Scale lattice vectors Does not scale `origin`. @@ -961,10 +1001,12 @@ def scale(self, scale, what="abc"): return self.copy((self.cell.T * scale).T) if what == "xyz": return self.copy(self.cell * scale) - raise ValueError(f"{self.__class__.__name__}.scale argument what='{what}' is not in ['abc', 'xyz'].") + raise ValueError( + f"{self.__class__.__name__}.scale argument what='{what}' is not in ['abc', 'xyz']." + ) def tile(self, reps, axis): - """ Extend the unit-cell `reps` times along the `axis` lattice vector + """Extend the unit-cell `reps` times along the `axis` lattice vector Notes ----- @@ -990,7 +1032,7 @@ def tile(self, reps, axis): return self.__class__(cell, nsc=nsc, origin=origin) def repeat(self, reps, axis): - """ Extend the unit-cell `reps` times along the `axis` lattice vector + """Extend the unit-cell `reps` times along the `axis` lattice vector Notes ----- @@ -1024,21 +1066,21 @@ def untile(self, reps, axis): unrepeat = untile def append(self, other, axis): - """ Appends other `Lattice` to this grid along axis """ + """Appends other `Lattice` to this grid along axis""" cell = np.copy(self.cell) cell[axis, :] += other.cell[axis, :] # TODO fix nsc here return self.copy(cell) def prepend(self, other, axis): - """ Prepends other `Lattice` to this grid along axis + """Prepends other `Lattice` to this grid along axis For a `Lattice` object this is equivalent to `append`. """ return other.append(self, axis) def translate(self, v): - """ Appends additional space to the object """ + """Appends additional space to the object""" # check which cell vector resembles v the most, # use that cell = np.copy(self.cell) @@ -1048,17 +1090,18 @@ def translate(self, v): p[i] = abs(np.sum(cell[i, :] * v)) / cl[i] cell[np.argmax(p), :] += v return self.copy(cell) + move = translate def center(self, axis=None): - """ Returns center of the `Lattice`, possibly with respect to an axis """ + """Returns center of the `Lattice`, possibly with respect to an axis""" if axis is None: return self.cell.sum(0) * 0.5 return self.cell[axis, :] * 0.5 @classmethod def tocell(cls, *args): - r""" Returns a 3x3 unit-cell dependent on the input + r"""Returns a 3x3 unit-cell dependent on the input 1 argument a unit-cell along Cartesian coordinates with side-length @@ -1111,7 +1154,8 @@ def tocell(cls, *args): gamma = args[5] from math import cos, pi, sin, sqrt - pi180 = pi / 180. + + pi180 = pi / 180.0 cell[0, 0] = a g = gamma * pi180 @@ -1126,7 +1170,7 @@ def tocell(cls, *args): a = alpha * pi180 d = (cos(a) - cb * cg) / sg cell[2, 1] = c * d - cell[2, 2] = c * sqrt(sb ** 2 - d ** 2) + cell[2, 2] = c * sqrt(sb**2 - d**2) return cell # A complete cell @@ -1134,7 +1178,8 @@ def tocell(cls, *args): return args.copy().reshape(3, 3) raise ValueError( - "Creating a unit-cell has to have 1, 3 or 6 arguments, please correct.") + "Creating a unit-cell has to have 1, 3 or 6 arguments, please correct." + ) def is_orthogonal(self, tol=0.001): """ @@ -1171,7 +1216,7 @@ def is_cartesian(self, tol=0.001): return ~np.any(np.abs(off_diagonal) > tol) def parallel(self, other, axis=(0, 1, 2)): - """ Returns true if the cell vectors are parallel to `other` + """Returns true if the cell vectors are parallel to `other` Parameters ---------- @@ -1190,7 +1235,7 @@ def parallel(self, other, axis=(0, 1, 2)): return True def angle(self, i, j, rad=False): - """ The angle between two of the cell vectors + """The angle between two of the cell vectors Parameters ---------- @@ -1209,7 +1254,7 @@ def angle(self, i, j, rad=False): @staticmethod def read(sile, *args, **kwargs): - """ Reads the supercell from the `Sile` using ``Sile.read_lattice`` + """Reads the supercell from the `Sile` using ``Sile.read_lattice`` Parameters ---------- @@ -1220,14 +1265,15 @@ def read(sile, *args, **kwargs): # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_lattice(*args, **kwargs) else: - with get_sile(sile, mode='r') as fh: + with get_sile(sile, mode="r") as fh: return fh.read_lattice(*args, **kwargs) def equal(self, other, tol=1e-4): - """ Check whether two lattices are equivalent + """Check whether two lattices are equivalent Parameters ---------- @@ -1242,7 +1288,8 @@ def equal(self, other, tol=1e-4): return same def __str__(self): - """ Returns a string representation of the object """ + """Returns a string representation of the object""" + # Create format for lattice vectors def bcstr(bc): left = BoundaryCondition.getitem(bc[0]).name.capitalize() @@ -1251,7 +1298,13 @@ def bcstr(bc): return left right = BoundaryCondition.getitem(bc[1]).name.capitalize() return f"[{left}, {right}]" - s = ',\n '.join(['ABC'[i] + '=[{:.4f}, {:.4f}, {:.4f}]'.format(*self.cell[i]) for i in (0, 1, 2)]) + + s = ",\n ".join( + [ + "ABC"[i] + "=[{:.4f}, {:.4f}, {:.4f}]".format(*self.cell[i]) + for i in (0, 1, 2) + ] + ) origin = "{:.4f}, {:.4f}, {:.4f}".format(*self.origin) bc = ",\n ".join(map(bcstr, self.boundary_condition)) return f"{self.__class__.__name__}{{nsc: {self.nsc},\n origin={origin},\n {s},\n bc=[{bc}]\n}}" @@ -1260,6 +1313,7 @@ def __repr__(self): a, b, c, alpha, beta, gamma = map(lambda r: round(r, 4), self.parameters()) BC = BoundaryCondition bc = self.boundary_condition + def bcstr(bc): left = BC.getitem(bc[0]).name[0] if bc[0] == bc[1]: @@ -1267,29 +1321,35 @@ def bcstr(bc): return left right = BC.getitem(bc[1]).name[0] return f"[{left}, {right}]" + bc = ", ".join(map(bcstr, self.boundary_condition)) return f"<{self.__module__}.{self.__class__.__name__} a={a}, b={b}, c={c}, α={alpha}, β={beta}, γ={gamma}, bc=[{bc}], nsc={self.nsc}>" def __eq__(self, other): - """ Equality check """ + """Equality check""" return self.equal(other) def __ne__(self, b): - """ In-equality check """ + """In-equality check""" return not (self == b) # Create pickling routines def __getstate__(self): - """ Returns the state of this object """ - return {'cell': self.cell, 'nsc': self.nsc, 'sc_off': self.sc_off, 'origin': self.origin} + """Returns the state of this object""" + return { + "cell": self.cell, + "nsc": self.nsc, + "sc_off": self.sc_off, + "origin": self.origin, + } def __setstate__(self, d): - """ Re-create the state of this object """ - self.__init__(d['cell'], d['nsc'], d['origin']) - self.sc_off = d['sc_off'] + """Re-create the state of this object""" + self.__init__(d["cell"], d["nsc"], d["origin"]) + self.sc_off = d["sc_off"] def __plot__(self, axis=None, axes=False, *args, **kwargs): - """ Plot the supercell in a specified ``matplotlib.Axes`` object. + """Plot the supercell in a specified ``matplotlib.Axes`` object. Parameters ---------- @@ -1304,17 +1364,17 @@ def __plot__(self, axis=None, axes=False, *args, **kwargs): d = dict() # Try and default the color and alpha - if 'color' not in kwargs and len(args) == 0: - kwargs['color'] = 'k' - if 'alpha' not in kwargs: - kwargs['alpha'] = 0.5 + if "color" not in kwargs and len(args) == 0: + kwargs["color"] = "k" + if "alpha" not in kwargs: + kwargs["alpha"] = 0.5 if axis is None: axis = [0, 1, 2] # Ensure we have a new 3D Axes3D if len(axis) == 3: - d['projection'] = '3d' + d["projection"] = "3d" axes = plt.get_axes(axes, **d) @@ -1325,15 +1385,21 @@ def __plot__(self, axis=None, axes=False, *args, **kwargs): v.append(np.vstack((o[axis], o[axis] + self.cell[a, axis]))) v = np.array(v) - if axes.__class__.__name__.startswith('Axes3D'): + if axes.__class__.__name__.startswith("Axes3D"): # We should plot in 3D plots for vv in v: axes.plot(vv[:, 0], vv[:, 1], vv[:, 2], *args, **kwargs) v0, v1 = v[0], v[1] - o - axes.plot(v0[1, 0] + v1[:, 0], v0[1, 1] + v1[:, 1], v0[1, 2] + v1[:, 2], *args, **kwargs) + axes.plot( + v0[1, 0] + v1[:, 0], + v0[1, 1] + v1[:, 1], + v0[1, 2] + v1[:, 2], + *args, + **kwargs, + ) - axes.set_zlabel('Ang') + axes.set_zlabel("Ang") else: for vv in v: @@ -1343,17 +1409,19 @@ def __plot__(self, axis=None, axes=False, *args, **kwargs): axes.plot(v0[1, 0] + v1[:, 0], v0[1, 1] + v1[:, 1], *args, **kwargs) axes.plot(v1[1, 0] + v0[:, 0], v1[1, 1] + v0[:, 1], *args, **kwargs) - axes.set_xlabel('Ang') - axes.set_ylabel('Ang') + axes.set_xlabel("Ang") + axes.set_ylabel("Ang") return axes + new_dispatch = Lattice.new to_dispatch = Lattice.to + # Define base-class for this class LatticeNewDispatch(AbstractDispatch): - """ Base dispatcher from class passing arguments to Geometry class + """Base dispatcher from class passing arguments to Geometry class This forwards all `__call__` calls to `dispatch` """ @@ -1361,63 +1429,80 @@ class LatticeNewDispatch(AbstractDispatch): def __call__(self, *args, **kwargs): return self.dispatch(*args, **kwargs) + class LatticeNewLatticeDispatch(LatticeNewDispatch): def dispatch(self, lattice, copy=False): # for sanitation purposes if copy: return lattice.copy() return lattice + + new_dispatch.register(Lattice, LatticeNewLatticeDispatch) + class LatticeNewAseDispatch(LatticeNewDispatch): def dispatch(self, aseg): cls = self._get_class(allow_instance=True) cell = aseg.get_cell() nsc = [3 if pbc else 1 for pbc in aseg.pbc] return cls(cell, nsc=nsc) + + new_dispatch.register("ase", LatticeNewAseDispatch) # currently we can't ensure the ase Atoms type # to get it by type(). That requires ase to be importable. try: from ase import Cell as ase_Cell + new_dispatch.register(ase_Cell, LatticeNewAseDispatch) # ensure we don't pollute name-space del ase_Cell except Exception: pass + class LatticeNewFileDispatch(LatticeNewDispatch): def dispatch(self, *args, **kwargs): - """ Defer the `Lattice.read` method by passing down arguments """ + """Defer the `Lattice.read` method by passing down arguments""" # can work either on class or instance return self._obj.read(*args, **kwargs) + + new_dispatch.register(str, LatticeNewFileDispatch) new_dispatch.register(Path, LatticeNewFileDispatch) # see sisl/__init__.py for new_dispatch.register(BaseSile, ...) class LatticeToDispatch(AbstractDispatch): - """ Base dispatcher from class passing from Lattice class """ + """Base dispatcher from class passing from Lattice class""" + class LatticeToAseDispatch(LatticeToDispatch): def dispatch(self, **kwargs): from ase import Cell as ase_Cell + lattice = self._get_object() return ase_Cell(lattice.cell.copy()) + to_dispatch.register("ase", LatticeToAseDispatch) + class LatticeToSileDispatch(LatticeToDispatch): def dispatch(self, *args, **kwargs): lattice = self._get_object() return lattice.write(*args, **kwargs) + + to_dispatch.register("str", LatticeToSileDispatch) to_dispatch.register("Path", LatticeToSileDispatch) # to do geom.to[Path](path) to_dispatch.register(str, LatticeToSileDispatch) to_dispatch.register(Path, LatticeToSileDispatch) + class LatticeToCuboidDispatch(LatticeToDispatch): def dispatch(self, center=None, origin=None, orthogonal=False): lattice = self._get_object() @@ -1441,6 +1526,7 @@ def find_min_max(cmin, cmax, new): for i in range(3): cmin[i] = min(cmin[i], new[i]) cmax[i] = max(cmax[i], new[i]) + cmin = cell.min(0) cmax = cell.max(0) find_min_max(cmin, cmax, cell[[0, 1], :].sum(0)) @@ -1449,6 +1535,7 @@ def find_min_max(cmin, cmax, new): find_min_max(cmin, cmax, cell.sum(0)) return Cuboid(cmax - cmin, center_off) + to_dispatch.register("Cuboid", LatticeToCuboidDispatch) to_dispatch.register(Cuboid, LatticeToCuboidDispatch) @@ -1458,14 +1545,18 @@ def find_min_max(cmin, cmax, new): class SuperCell(Lattice): - """ Deprecated class, please use `Lattice` instead """ + """Deprecated class, please use `Lattice` instead""" + def __init__(self, *args, **kwargs): - deprecate(f"{self.__class__.__name__} is deprecated; please use 'Lattice' class instead", "0.15") + deprecate( + f"{self.__class__.__name__} is deprecated; please use 'Lattice' class instead", + "0.15", + ) super().__init__(*args, **kwargs) class LatticeChild: - """ Class to be inherited by using the ``self.lattice`` as a `Lattice` object + """Class to be inherited by using the ``self.lattice`` as a `Lattice` object Initialize by a `Lattice` object and get access to several different routines directly related to the `Lattice` class. @@ -1474,11 +1565,14 @@ class LatticeChild: @property def sc(self): """[deprecated] Return the lattice object associated with the `Lattice`.""" - deprecate(f"{self.__class__.__name__}.sc is deprecated; please use 'lattice' instead", "0.15") + deprecate( + f"{self.__class__.__name__}.sc is deprecated; please use 'lattice' instead", + "0.15", + ) return self.lattice def set_nsc(self, *args, **kwargs): - """ Set the number of super-cells in the `Lattice` object + """Set the number of super-cells in the `Lattice` object See `set_nsc` for allowed parameters. @@ -1493,7 +1587,7 @@ def set_lattice(self, lattice): if lattice is None: # Default supercell is a simple # 1x1x1 unit-cell - self.lattice = Lattice([1., 1., 1.]) + self.lattice = Lattice([1.0, 1.0, 1.0]) elif isinstance(lattice, Lattice): self.lattice = lattice elif isinstance(lattice, LatticeChild): @@ -1502,8 +1596,12 @@ def set_lattice(self, lattice): # The supercell is given as a cell self.lattice = Lattice(lattice) - set_sc = deprecation("set_sc is deprecated; please use set_lattice instead", "0.14")(set_lattice) - set_supercell = deprecation("set_sc is deprecated; please use set_lattice instead", "0.15")(set_lattice) + set_sc = deprecation( + "set_sc is deprecated; please use set_lattice instead", "0.14" + )(set_lattice) + set_supercell = deprecation( + "set_sc is deprecated; please use set_lattice instead", "0.15" + )(set_lattice) @property def length(self): @@ -1512,58 +1610,57 @@ def length(self): @property def volume(self): - """ Returns the inherent `Lattice` objects `volume` """ + """Returns the inherent `Lattice` objects `volume`""" return self.lattice.volume def area(self, ax0, ax1): - """ Calculate the area spanned by the two axis `ax0` and `ax1` """ + """Calculate the area spanned by the two axis `ax0` and `ax1`""" return self.lattice.area(ax0, ax1) @property def cell(self): - """ Returns the inherent `Lattice` objects `cell` """ + """Returns the inherent `Lattice` objects `cell`""" return self.lattice.cell @property def icell(self): - """ Returns the inherent `Lattice` objects `icell` """ + """Returns the inherent `Lattice` objects `icell`""" return self.lattice.icell @property def rcell(self): - """ Returns the inherent `Lattice` objects `rcell` """ + """Returns the inherent `Lattice` objects `rcell`""" return self.lattice.rcell @property def origin(self): - """ Returns the inherent `Lattice` objects `origin` """ + """Returns the inherent `Lattice` objects `origin`""" return self.lattice.origin @property def n_s(self): - """ Returns the inherent `Lattice` objects `n_s` """ + """Returns the inherent `Lattice` objects `n_s`""" return self.lattice.n_s @property def nsc(self): - """ Returns the inherent `Lattice` objects `nsc` """ + """Returns the inherent `Lattice` objects `nsc`""" return self.lattice.nsc @property def sc_off(self): - """ Returns the inherent `Lattice` objects `sc_off` """ + """Returns the inherent `Lattice` objects `sc_off`""" return self.lattice.sc_off @property def isc_off(self): - """ Returns the inherent `Lattice` objects `isc_off` """ + """Returns the inherent `Lattice` objects `isc_off`""" return self.lattice.isc_off def sc_index(self, *args, **kwargs): - """ Call local `Lattice` object `sc_index` function """ + """Call local `Lattice` object `sc_index` function""" return self.lattice.sc_index(*args, **kwargs) - @property def boundary_condition(self) -> np.ndarray: f"""{Lattice.boundary_condition.__doc__}""" @@ -1572,10 +1669,11 @@ def boundary_condition(self) -> np.ndarray: @boundary_condition.setter def boundary_condition(self, boundary_condition: Sequence[BoundaryConditionType]): f"""{Lattice.boundary_condition.__doc__}""" - raise SislError(f"Cannot use property to set boundary conditions of LatticeChild") + raise SislError( + f"Cannot use property to set boundary conditions of LatticeChild" + ) @property def pbc(self) -> np.ndarray: f"""{Lattice.pbc.__doc__}""" return self.lattice.pbc - diff --git a/src/sisl/linalg/base.py b/src/sisl/linalg/base.py index ffe721ac73..523893b126 100644 --- a/src/sisl/linalg/base.py +++ b/src/sisl/linalg/base.py @@ -36,14 +36,14 @@ def _datacopied(arr, original): # I.e. when fetching the same method over and over # we should be able to reduce the overhead by retrieving the intrinsic version. _linalg_info_dtype = { - np.float32: 'f4', - np.float64: 'f8', - np.complex64: 'c8', - np.complex128: 'c16', - 'f4': 'f4', - 'f8': 'f8', - 'c8': 'c8', - 'c16': 'c16', + np.float32: "f4", + np.float64: "f8", + np.complex64: "c8", + np.complex128: "c16", + "f4": "f4", + "f8": "f8", + "c8": "c8", + "c16": "c16", } _linalg_info_base = {} # Initialize the base-dtype dicts @@ -52,8 +52,10 @@ def _datacopied(arr, original): @set_module("sisl.linalg") -def linalg_info(method, dtype, method_dict=_linalg_info_base, dtype_dict=_linalg_info_dtype): - """ Faster BLAS/LAPACK methods to be returned without too many lookups an array checks +def linalg_info( + method, dtype, method_dict=_linalg_info_base, dtype_dict=_linalg_info_dtype +): + """Faster BLAS/LAPACK methods to be returned without too many lookups an array checks Parameters ---------- @@ -85,7 +87,7 @@ def linalg_info(method, dtype, method_dict=_linalg_info_base, dtype_dict=_linalg try: func = get_lapack_funcs(method, dtype=dtype) except ValueError as e: - if 'LAPACK function' in str(e): + if "LAPACK function" in str(e): func = get_blas_funcs(method, dtype=dtype) else: raise e @@ -94,18 +96,17 @@ def linalg_info(method, dtype, method_dict=_linalg_info_base, dtype_dict=_linalg def _compute_lwork(routine, *args, **kwargs): - """ See scipy.linalg.lapack._compute_lwork """ + """See scipy.linalg.lapack._compute_lwork""" wi = routine(*args, **kwargs) if len(wi) < 2: - raise ValueError('') + raise ValueError("") info = wi[-1] if info != 0: - raise ValueError("Internal work array size computation failed: " - "%d" % (info,)) + raise ValueError("Internal work array size computation failed: " "%d" % (info,)) lwork = [w.real for w in wi[:-1]] - dtype = getattr(routine, 'dtype', None) + dtype = getattr(routine, "dtype", None) if dtype == np.float32 or dtype == np.complex64: # Single-precision routine -- take next fp value to work # around possible truncation in LAPACK code @@ -113,8 +114,10 @@ def _compute_lwork(routine, *args, **kwargs): lwork = np.array(lwork, np.int64) if np.any(np.logical_or(lwork < 0, lwork > np.iinfo(np.int32).max)): - raise ValueError("Too large work array required -- computation cannot " - "be performed with standard 32-bit LAPACK.") + raise ValueError( + "Too large work array required -- computation cannot " + "be performed with standard 32-bit LAPACK." + ) lwork = lwork.astype(np.int32) if lwork.size == 1: return lwork[0] @@ -143,10 +146,11 @@ def inv(a, overwrite_a=False): overwrite_a = overwrite_a or _datacopied(a1, a) if a1.shape[0] != a1.shape[1]: - raise ValueError('Input a needs to be a square matrix.') + raise ValueError("Input a needs to be a square matrix.") - getrf, getri, getri_lwork = get_lapack_funcs(('getrf', 'getri', - 'getri_lwork'), (a1,)) + getrf, getri, getri_lwork = get_lapack_funcs( + ("getrf", "getri", "getri_lwork"), (a1,) + ) lu, piv, info = getrf(a1, overwrite_a=overwrite_a) if info == 0: lwork = _compute_lwork(getri_lwork, a1.shape[0]) @@ -155,8 +159,9 @@ def inv(a, overwrite_a=False): if info > 0: raise LinAlgError("Singular matrix") elif info < 0: - raise ValueError('illegal value in %d-th argument of internal ' - 'getrf|getri' % -info) + raise ValueError( + "illegal value in %d-th argument of internal " "getrf|getri" % -info + ) return x @@ -189,13 +194,12 @@ def solve(a, b, overwrite_a=False, overwrite_b=False, assume_a="gen"): overwrite_b = overwrite_b or _datacopied(b1, b) if a1.shape[0] != a1.shape[1]: - raise ValueError('LHS needs to be a square matrix.') + raise ValueError("LHS needs to be a square matrix.") if n != b1.shape[0]: # Last chance to catch 1x1 scalar a and 1D b arrays if not (n == 1 and b1.size != 0): - raise ValueError('Input b has to have same number of rows as ' - 'input a') + raise ValueError("Input b has to have same number of rows as " "input a") # regularize 1D b arrays to 2D if b1.ndim == 1: @@ -209,23 +213,25 @@ def solve(a, b, overwrite_a=False, overwrite_b=False, assume_a="gen"): if assume_a == "sym": lower = False - sysv, sysv_lw = get_lapack_funcs(("sysv", - "sysv_lwork"), (a1, b1)) + sysv, sysv_lw = get_lapack_funcs(("sysv", "sysv_lwork"), (a1, b1)) lwork = _compute_lwork(sysv_lw, n, lower) - _, _, x, info = sysv(a1, b1, lwork=lwork, - lower=lower, - overwrite_a=overwrite_a, - overwrite_b=overwrite_b) + _, _, x, info = sysv( + a1, + b1, + lwork=lwork, + lower=lower, + overwrite_a=overwrite_a, + overwrite_b=overwrite_b, + ) elif assume_a == "gen": - gesv = get_lapack_funcs('gesv', (a1, b1)) + gesv = get_lapack_funcs("gesv", (a1, b1)) _, _, x, info = gesv(a1, b1, overwrite_a=overwrite_a, overwrite_b=overwrite_b) else: raise ValueError("Input assume_a is not one of gen/sym") if info > 0: raise LinAlgError("Singular matrix") elif info < 0: - raise ValueError('illegal value in %d-th argument of internal ' - 'gesv' % -info) + raise ValueError("illegal value in %d-th argument of internal " "gesv" % -info) if b_is_1D: return x.ravel() @@ -235,6 +241,7 @@ def solve(a, b, overwrite_a=False, overwrite_b=False, assume_a="gen"): def _append(name, suffix): return [name + s for s in suffix] + # Solving a linear system solve_destroy = _partial(solve, overwrite_a=True, overwrite_b=True) __all__ += _append("solve", ["", "_destroy"]) @@ -250,13 +257,21 @@ def _append(name, suffix): # Solve eigenvalue problem eig = _partial(sl.eig, check_finite=False, overwrite_a=False, overwrite_b=False) -eig_left = _partial(sl.eig, check_finite=False, overwrite_a=False, overwrite_b=False, left=True) -eig_right = _partial(sl.eig, check_finite=False, overwrite_a=False, overwrite_b=False, right=True) +eig_left = _partial( + sl.eig, check_finite=False, overwrite_a=False, overwrite_b=False, left=True +) +eig_right = _partial( + sl.eig, check_finite=False, overwrite_a=False, overwrite_b=False, right=True +) __all__ += _append("eig", ["", "_left", "_right"]) eig_destroy = _partial(sl.eig, check_finite=False, overwrite_a=True, overwrite_b=True) -eig_left_destroy = _partial(sl.eig, check_finite=False, overwrite_a=True, overwrite_b=True, left=True) -eig_right_destroy = _partial(sl.eig, check_finite=False, overwrite_a=True, overwrite_b=True, right=True) +eig_left_destroy = _partial( + sl.eig, check_finite=False, overwrite_a=True, overwrite_b=True, left=True +) +eig_right_destroy = _partial( + sl.eig, check_finite=False, overwrite_a=True, overwrite_b=True, right=True +) __all__ += _append("eig_", ["destroy", "left_destroy", "right_destroy"]) eigvals = _partial(sl.eigvals, check_finite=False, overwrite_a=False) @@ -268,8 +283,12 @@ def _append(name, suffix): eigh_destroy = _partial(sl.eigh, check_finite=False, overwrite_a=True, overwrite_b=True) __all__ += _append("eigh", ["", "_destroy"]) -eigvalsh = _partial(sl.eigvalsh, check_finite=False, overwrite_a=False, overwrite_b=False) -eigvalsh_destroy = _partial(sl.eigvalsh, check_finite=False, overwrite_a=True, overwrite_b=True) +eigvalsh = _partial( + sl.eigvalsh, check_finite=False, overwrite_a=False, overwrite_b=False +) +eigvalsh_destroy = _partial( + sl.eigvalsh, check_finite=False, overwrite_a=True, overwrite_b=True +) __all__ += _append("eigvalsh", ["", "_destroy"]) cholesky = _partial(sl.cholesky, check_finite=False) diff --git a/src/sisl/linalg/special.py b/src/sisl/linalg/special.py index 21a6e5f580..e14574714c 100644 --- a/src/sisl/linalg/special.py +++ b/src/sisl/linalg/special.py @@ -5,11 +5,11 @@ from .base import eigh -__all__ = ['signsqrt', 'sqrth', 'invsqrth', 'lowdin'] +__all__ = ["signsqrt", "sqrth", "invsqrth", "lowdin"] def signsqrt(a): - r""" Calculate the sqrt of the elements `a` by retaining the sign. + r"""Calculate the sqrt of the elements `a` by retaining the sign. This only influences negative values in `a` by returning ``-abs(a)**0.5`` @@ -37,7 +37,7 @@ def sqrth(a, overwrite_a=False): def invsqrth(a, overwrite_a=False): - """ Calculate the inverse sqrt of the Hermitian matrix `H` + """Calculate the inverse sqrt of the Hermitian matrix `H` We do this by using eigh and taking the sqrt of the eigenvalues. @@ -52,7 +52,7 @@ def invsqrth(a, overwrite_a=False): def lowdin(a, b, overwrite_a=False): - r""" Convert the matrix `b` in the basis `a` into an orthogonal basis using the Lowdin transformation + r"""Convert the matrix `b` in the basis `a` into an orthogonal basis using the Lowdin transformation .. math:: diff --git a/src/sisl/linalg/tests/test_solve.py b/src/sisl/linalg/tests/test_solve.py index cab1901af8..eb4e9f16f0 100644 --- a/src/sisl/linalg/tests/test_solve.py +++ b/src/sisl/linalg/tests/test_solve.py @@ -28,7 +28,7 @@ def test_solve2(): xs = sl.solve(a, b) x = solve(a, b) assert np.allclose(xs, x) - assert x.shape == (10, ) + assert x.shape == (10,) assert np.allclose(a, ac) assert np.allclose(b, bc) @@ -48,4 +48,4 @@ def test_solve4(): xs = sl.solve(a, b) x = solve_destroy(a, b) assert np.allclose(xs, x) - assert x.shape == (10, ) + assert x.shape == (10,) diff --git a/src/sisl/messages.py b/src/sisl/messages.py index f96425f76e..4d35823700 100644 --- a/src/sisl/messages.py +++ b/src/sisl/messages.py @@ -26,9 +26,9 @@ from ._environ import get_environ_variable from ._internal import set_module -__all__ = ['SislDeprecation', 'SislInfo', 'SislWarning', 'SislException', 'SislError'] -__all__ += ['warn', 'info', 'deprecate', "deprecation", "deprecate_argument"] -__all__ += ['progressbar', 'tqdm_eta'] +__all__ = ["SislDeprecation", "SislInfo", "SislWarning", "SislException", "SislError"] +__all__ += ["warn", "info", "deprecate", "deprecation", "deprecate_argument"] +__all__ += ["progressbar", "tqdm_eta"] # The local registry for warnings issued _sisl_warn_registry = {} @@ -36,37 +36,42 @@ @set_module("sisl") class SislException(Exception): - """ Sisl exception """ + """Sisl exception""" + pass @set_module("sisl") class SislError(SislException): - """ Sisl error """ + """Sisl error""" + pass @set_module("sisl") class SislWarning(SislException, UserWarning): - """ Sisl warnings """ + """Sisl warnings""" + pass @set_module("sisl") class SislDeprecation(SislWarning, FutureWarning): - """ Sisl deprecation warnings for end-users """ + """Sisl deprecation warnings for end-users""" + pass @set_module("sisl") class SislInfo(SislWarning): - """ Sisl informations """ + """Sisl informations""" + pass @set_module("sisl") def deprecate(message, from_version=None): - """ Issue sisl deprecation warnings + """Issue sisl deprecation warnings Parameters ---------- @@ -77,17 +82,20 @@ def deprecate(message, from_version=None): """ if from_version is not None: message = f"{message} [>={from_version}]" - warnings.warn_explicit(message, SislDeprecation, 'dep', 0, registry=_sisl_warn_registry) + warnings.warn_explicit( + message, SislDeprecation, "dep", 0, registry=_sisl_warn_registry + ) @set_module("sisl") def deprecate_argument(old, new, message, from_version=None): - """ Decorator for deprecating `old` argument, and replacing it with `new` + """Decorator for deprecating `old` argument, and replacing it with `new` The old keyword argument is still retained. If `new` is none, it will be deleted. """ + def deco(func): @wraps(func) def wrapped(*args, **kwargs): @@ -96,13 +104,15 @@ def wrapped(*args, **kwargs): if new is not None: kwargs[new] = kwargs.pop(old) return func(*args, **kwargs) + return wrapped + return deco @set_module("sisl") def deprecation(message, from_version=None): - """ Decorator for deprecating a method or a class + """Decorator for deprecating a method or a class Parameters ---------- @@ -111,6 +121,7 @@ def deprecation(message, from_version=None): from_version : optional which version to deprecate this method from """ + def install_deprecate(cls_or_func): if isinstance(cls_or_func, type): # we have a class @@ -118,18 +129,22 @@ class wrapped(cls_or_func): @deprecation(message, from_version) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + else: + @wraps(cls_or_func) def wrapped(*args, **kwargs): deprecate(message, from_version) return cls_or_func(*args, **kwargs) + return wrapped + return install_deprecate @set_module("sisl") def warn(message, category=None, register=False): - """ Show warnings in short context form with sisl + """Show warnings in short context form with sisl Parameters ---------- @@ -146,14 +161,16 @@ def warn(message, category=None, register=False): elif category is None: category = SislWarning if register: - warnings.warn_explicit(message, category, 'warn', 0, registry=_sisl_warn_registry) + warnings.warn_explicit( + message, category, "warn", 0, registry=_sisl_warn_registry + ) else: - warnings.warn_explicit(message, category, 'warn', 0) + warnings.warn_explicit(message, category, "warn", 0) @set_module("sisl") def info(message, category=None, register=False): - """ Show info in short context form with sisl + """Show info in short context form with sisl Parameters ---------- @@ -170,24 +187,27 @@ def info(message, category=None, register=False): elif category is None: category = SislInfo if register: - warnings.warn_explicit(message, category, 'info', 0, registry=_sisl_warn_registry) + warnings.warn_explicit( + message, category, "info", 0, registry=_sisl_warn_registry + ) else: - warnings.warn_explicit(message, category, 'info', 0) + warnings.warn_explicit(message, category, "info", 0) # https://stackoverflow.com/a/39662359/827281 def is_jupyter_notebook(): try: shell = get_ipython().__class__.__name__ - if shell == 'ZMQInteractiveShell': + if shell == "ZMQInteractiveShell": return True - elif shell == 'TerminalInteractiveShell': + elif shell == "TerminalInteractiveShell": return False else: return False except NameError: return False + # Figure out if we can import tqdm. # If so, simply use the progressbar class there. # Otherwise, create a fake one. @@ -198,14 +218,18 @@ def is_jupyter_notebook(): from tqdm import tqdm as _tqdm except ImportError: # Notify user of better option - info('Please install tqdm (pip install tqdm) for better looking progress bars', register=True) + info( + "Please install tqdm (pip install tqdm) for better looking progress bars", + register=True, + ) # Necessary methods used from sys import stdout as _stdout from time import time as _time class _tqdm: - """ Fake tqdm progress-bar. I should update this to also work in regular instances """ + """Fake tqdm progress-bar. I should update this to also work in regular instances""" + __slots__ = ["total", "desc", "t0", "n", "l"] def __init__(self, total, desc, unit): @@ -228,7 +252,9 @@ def update(self, n=1): def close(self): m, s = divmod(_time() - self.t0, 60) h, m = divmod(m, 60) - _stdout.write(f"{self.desc} finished after {int(h):d}h {int(m):d}m {s:.1f}s\r") + _stdout.write( + f"{self.desc} finished after {int(h):d}h {int(m):d}m {s:.1f}s\r" + ) _stdout.flush() @@ -237,7 +263,7 @@ def close(self): @set_module("sisl") def progressbar(total, desc, unit, eta, **kwargs): - """ Create a progress bar in when it is requested. Otherwise returns a fake object + """Create a progress bar in when it is requested. Otherwise returns a fake object Parameters ---------- @@ -268,9 +294,12 @@ def progressbar(total, desc, unit, eta, **kwargs): # has the required 2 methods, update and close. class Fake: __slots__ = [] + def update(self, n=1): pass + def close(self): pass + bar = Fake() return bar diff --git a/src/sisl/mixing/base.py b/src/sisl/mixing/base.py index c708a5acbf..3630d2d030 100644 --- a/src/sisl/mixing/base.py +++ b/src/sisl/mixing/base.py @@ -35,12 +35,12 @@ @set_module("sisl.mixing") class BaseMixer: - r""" Base class mixer """ + r"""Base class mixer""" __slots__ = () @abstractmethod def __call__(self, f: T, df: T, *args: Any, **kwargs: Any) -> T: - """ Mix quantities based on arguments """ + """Mix quantities based on arguments""" def __add__(self, other: Union[float, int, TypeBaseMixer]) -> TypeCompositeMixer: return CompositeMixer(op.add, self, other) @@ -60,10 +60,14 @@ def __mul__(self, factor: Union[float, int, TypeBaseMixer]) -> TypeCompositeMixe def __rmul__(self, factor: Union[float, int, TypeBaseMixer]) -> TypeCompositeMixer: return CompositeMixer(op.mul, self, factor) - def __truediv__(self, divisor: Union[float, int, TypeBaseMixer]) -> TypeCompositeMixer: + def __truediv__( + self, divisor: Union[float, int, TypeBaseMixer] + ) -> TypeCompositeMixer: return CompositeMixer(op.truediv, self, divisor) - def __rtruediv__(self, divisor: Union[float, int, TypeBaseMixer]) -> TypeCompositeMixer: + def __rtruediv__( + self, divisor: Union[float, int, TypeBaseMixer] + ) -> TypeCompositeMixer: return CompositeMixer(op.truediv, divisor, self) def __neg__(self) -> TypeCompositeMixer: @@ -78,7 +82,7 @@ def __rpow__(self, other: Union[float, int, TypeBaseMixer]) -> TypeCompositeMixe @set_module("sisl.mixing") class CompositeMixer(BaseMixer): - """ Placeholder for two metrics """ + """Placeholder for two metrics""" __slots__ = ("_op", "A", "B") @@ -100,11 +104,11 @@ def __call__(self, f: T, df: T, *args: Any, **kwargs: Any) -> T: def __str__(self) -> str: if isinstance(self.A, BaseMixer): - A = "({})".format(repr(self.A).replace('\n', '\n ')) + A = "({})".format(repr(self.A).replace("\n", "\n ")) else: A = f"{self.A}" if isinstance(self.B, BaseMixer): - B = "({})".format(repr(self.B).replace('\n', '\n ')) + B = "({})".format(repr(self.B).replace("\n", "\n ")) else: B = f"{self.B}" return f"{self.__class__.__name__}{{{self._op.__name__}({A}, {B})}}" @@ -112,7 +116,7 @@ def __str__(self) -> str: @set_module("sisl.mixing") class BaseWeightMixer(BaseMixer): - r""" Base class mixer """ + r"""Base class mixer""" __slots__ = ("_weight",) def __init__(self, weight: TypeWeight = 0.2): @@ -120,11 +124,11 @@ def __init__(self, weight: TypeWeight = 0.2): @property def weight(self) -> TypeWeight: - """ This mixers mixing weight, the weight is the fractional contribution of the derivative """ + """This mixers mixing weight, the weight is the fractional contribution of the derivative""" return self._weight def set_weight(self, weight: TypeWeight): - """ Set a new weight for this mixer + """Set a new weight for this mixer Parameters ---------- @@ -137,7 +141,7 @@ def set_weight(self, weight: TypeWeight): @set_module("sisl.mixing") class BaseHistoryWeightMixer(BaseWeightMixer): - r""" Base class mixer with history """ + r"""Base class mixer with history""" __slots__ = ("_history",) def __init__(self, weight: TypeWeight = 0.2, history: TypeArgHistory = 0): @@ -145,19 +149,18 @@ def __init__(self, weight: TypeWeight = 0.2, history: TypeArgHistory = 0): self.set_history(history) def __str__(self) -> str: - r""" String representation """ + r"""String representation""" hist = str(self.history).replace("\n", "\n ") return f"{self.__class__.__name__}{{weight: {self.weight:.4f},\n {hist}\n}}" def __repr__(self) -> str: - r""" String representation """ + r"""String representation""" hist = len(self.history) max_hist = self.history.max_elements return f"{self.__class__.__name__}{{weight: {self.weight:.4f}, history={hist}|{max_hist}}}" - def __call__(self, f: T, df: T, *args: Any, append: bool = True) -> None: - """ Append data to the history (omitting None values)! """ + """Append data to the history (omitting None values)!""" if not append: # do nothing return @@ -170,11 +173,11 @@ def __call__(self, f: T, df: T, *args: Any, append: bool = True) -> None: @property def history(self) -> TypeHistory: - """ History object tracked by this mixer """ + """History object tracked by this mixer""" return self._history def set_history(self, history: TypeArgHistory) -> None: - """ Replace the current history in the mixer with a new one + """Replace the current history in the mixer with a new one Parameters ---------- @@ -189,7 +192,7 @@ def set_history(self, history: TypeArgHistory) -> None: @set_module("sisl.mixing") class StepMixer(BaseMixer): - """ Step between different mixers in a user-defined fashion + """Step between different mixers in a user-defined fashion This is handy for creating variable mixing schemes that alternates (or differently) between multiple mixers. @@ -235,7 +238,7 @@ def __init__(self, *yield_funcs: TypeStepCallable): self._mixer = next(self._yield_mixer) def next(self) -> TypeBaseMixer: - """ Return the current mixer, and step the internal mixer """ + """Return the current mixer, and step the internal mixer""" mixer = self._mixer try: self._mixer = next(self._yield_mixer) @@ -247,15 +250,15 @@ def next(self) -> TypeBaseMixer: @property def mixer(self) -> TypeBaseMixer: - """ Return the current mixer """ + """Return the current mixer""" return self._mixer def __call__(self, f: T, df: T, *args: Any, **kwargs: Any) -> T: - """ Apply the mixing routine """ + """Apply the mixing routine""" return self.next()(f, df, *args, **kwargs) def __getattr__(self, attr: str) -> Any: - """ Divert all unknown attributes to the current mixer + """Divert all unknown attributes to the current mixer Note that available attributes may be different for different mixers. @@ -263,22 +266,30 @@ def __getattr__(self, attr: str) -> Any: return getattr(self.mixer, attr) @classmethod - def yield_repeat(cls: TypeStepMixer, mixer: TypeBaseMixer, n: int) -> TypeStepCallable: - """ Returns a function which repeats `mixer` `n` times """ + def yield_repeat( + cls: TypeStepMixer, mixer: TypeBaseMixer, n: int + ) -> TypeStepCallable: + """Returns a function which repeats `mixer` `n` times""" if n == 1: + def yield_repeat() -> Iterator[TypeBaseMixer]: - f""" Yield the mixer {mixer} 1 time """ + f"""Yield the mixer {mixer} 1 time""" yield mixer + else: + def yield_repeat() -> Iterator[TypeBaseMixer]: - f""" Yield the mixer {mixer} {n} times """ + f"""Yield the mixer {mixer} {n} times""" for _ in range(n): yield mixer + return yield_repeat @classmethod - def yield_chain(cls: TypeStepMixer, *yield_funcs: TypeStepCallable) -> TypeStepCallable: - """ Returns a function which yields from each of the function arguments in turn + def yield_chain( + cls: TypeStepMixer, *yield_funcs: TypeStepCallable + ) -> TypeStepCallable: + """Returns a function which yields from each of the function arguments in turn Basically equivalent to a function which does this: @@ -292,16 +303,18 @@ def yield_chain(cls: TypeStepMixer, *yield_funcs: TypeStepCallable) -> TypeStepC """ if len(yield_funcs) == 1: return yield_funcs[0] + def yield_chain() -> Iterator[TypeBaseMixer]: - f""" Yield from the different yield generators """ + f"""Yield from the different yield generators""" for yield_func in yield_funcs: yield from yield_func() + return yield_chain @set_module("sisl.mixing") class History: - r""" A history class for retaining a set of history elements + r"""A history class for retaining a set of history elements A history class may contain several different variables in a `collections.deque` list allowing easy managing of the length of the history. @@ -322,17 +335,19 @@ def __init__(self, history: int = 2): self._hist = deque(maxlen=history) def __str__(self) -> str: - """ str of the object """ - return f"{self.__class__.__name__}{{history: {self.elements}/{self.max_elements}}}" + """str of the object""" + return ( + f"{self.__class__.__name__}{{history: {self.elements}/{self.max_elements}}}" + ) @property def max_elements(self) -> int: - r""" Maximum number of elements stored in the history for each variable """ + r"""Maximum number of elements stored in the history for each variable""" return self._hist.maxlen @property def elements(self) -> int: - r""" Number of elements in the history """ + r"""Number of elements in the history""" return len(self._hist) def __len__(self) -> int: @@ -348,7 +363,7 @@ def __delitem__(self, key: Union[int, ArrayLike]) -> None: self.clear(key) def append(self, *variables: Any) -> None: - r""" Add variables to the history + r"""Add variables to the history Internally, the list of variables will be added to the queue, it is up to the implementation to use the appended values. @@ -360,8 +375,8 @@ def append(self, *variables: Any) -> None: """ self._hist.append(variables) - def clear(self, index: Optional[Union[int, ArrayLike]]=None) -> None: - r""" Clear variables to the history + def clear(self, index: Optional[Union[int, ArrayLike]] = None) -> None: + r"""Clear variables to the history Parameters ---------- diff --git a/src/sisl/mixing/diis.py b/src/sisl/mixing/diis.py index aa6bd36040..7a4505e550 100644 --- a/src/sisl/mixing/diis.py +++ b/src/sisl/mixing/diis.py @@ -70,17 +70,23 @@ class DIISMixer(BaseHistoryWeightMixer): """ __slots__ = ("_metric",) - def __init__(self, weight: TypeWeight = 0.1, history: TypeArgHistory = 2, - metric: Optional[TypeMetric] = None): + def __init__( + self, + weight: TypeWeight = 0.1, + history: TypeArgHistory = 2, + metric: Optional[TypeMetric] = None, + ): # This will call self.set_history(history) super().__init__(weight, history) if metric is None: + def metric(a, b): return a.ravel().conj().dot(b.ravel()).real + self._metric = metric def solve_lagrange(self) -> Tuple[NDArray, NDArray]: - r""" Calculate the coefficients according to Pulay's method, return everything + Lagrange multiplier """ + r"""Calculate the coefficients according to Pulay's method, return everything + Lagrange multiplier""" hist = self.history n_h = len(hist) metric = self._metric @@ -88,9 +94,9 @@ def solve_lagrange(self) -> Tuple[NDArray, NDArray]: if n_h == 0: # Externally the coefficients should reflect the weight per previous iteration. # The mixing weight is an additional parameter - return _a.arrayd([1.]), 100. + return _a.arrayd([1.0]), 100.0 elif n_h == 1: - return _a.arrayd([1.]), metric(hist[0][-1], hist[0][-1]) + return _a.arrayd([1.0]), metric(hist[0][-1], hist[0][-1]) # Initialize the matrix to be solved against B = _a.emptyd([n_h + 1, n_h + 1]) @@ -104,14 +110,14 @@ def solve_lagrange(self) -> Tuple[NDArray, NDArray]: B[i, j] = metric(ei, ej) B[j, i] = B[i, j] - B[:, n_h] = 1. - B[n_h, :] = 1. - B[n_h, n_h] = 0. + B[:, n_h] = 1.0 + B[n_h, :] = 1.0 + B[n_h, n_h] = 0.0 # Although B contains 1 and a number on the order of # number of elements (self._hist.size), it seems very # numerically stable. - last_metric = B[n_h-1, n_h-1] + last_metric = B[n_h - 1, n_h - 1] # Create RHS RHS = _a.zerosd(n_h + 1) @@ -125,28 +131,30 @@ def solve_lagrange(self) -> Tuple[NDArray, NDArray]: return c[:-1], -c[-1] except np.linalg.LinAlgError as e: # We have a LinalgError - return _a.arrayd([1.]), last_metric + return _a.arrayd([1.0]), last_metric def coefficients(self) -> NDArray: - r""" Calculate coefficients of the Lagrangian """ + r"""Calculate coefficients of the Lagrangian""" c, lagrange = self.solve_lagrange() return c def mix(self, coefficients: NDArray) -> Any: - r""" Calculate a new variable :math:`f'` using history and input coefficients + r"""Calculate a new variable :math:`f'` using history and input coefficients Parameters ---------- coefficients : numpy.ndarray coefficients used for extrapolation """ + def frac_hist(coef, hist): return coef * (hist[0] + self.weight * hist[1]) + return reduce(add, map(frac_hist, coefficients, self.history)) - def __call__(self, f: T, df: T, - delta: Optional[Any] = None, - append: bool = True) -> T: + def __call__( + self, f: T, df: T, delta: Optional[Any] = None, append: bool = True + ) -> T: # Add to history super().__call__(f, df, delta, append=append) @@ -154,12 +162,12 @@ def __call__(self, f: T, df: T, return self.mix(self.coefficients()) -PulayMixer = set_module("sisl.mixing")(type("PulayMixer", (DIISMixer, ), {})) +PulayMixer = set_module("sisl.mixing")(type("PulayMixer", (DIISMixer,), {})) @set_module("sisl.mixing") class AdaptiveDIISMixer(DIISMixer): - r""" Adapt the mixing weight according to the Lagrange multiplier + r"""Adapt the mixing weight according to the Lagrange multiplier The Lagrange multiplier calculated in a DIIS/Pulay mixing scheme is the squared norm of the residual that is minimized using the @@ -173,19 +181,25 @@ class AdaptiveDIISMixer(DIISMixer): """ __slots__ = ("_weight_min", "_weight_delta") - def __init__(self, weight: Tuple[TypeWeight, TypeWeight] = (0.03, 0.5), - history: TypeArgHistory = 2, - metric: Optional[TypeMetric] = None): + def __init__( + self, + weight: Tuple[TypeWeight, TypeWeight] = (0.03, 0.5), + history: TypeArgHistory = 2, + metric: Optional[TypeMetric] = None, + ): if isinstance(weight, Real): - weight = (max(0.001, weight * 0.1), min(1., weight * 2)) + weight = (max(0.001, weight * 0.1), min(1.0, weight * 2)) super().__init__(weight[0], history, metric) self._weight_min = weight[0] self._weight_delta = weight[1] - weight[0] - def adjust_weight(self, lagrange: Any, - offset: Union[float, int] = 13, - spread: Union[float, int] = 7) -> None: - r""" Adjust the weight according to the Lagrange multiplier. + def adjust_weight( + self, + lagrange: Any, + offset: Union[float, int] = 13, + spread: Union[float, int] = 7, + ) -> None: + r"""Adjust the weight according to the Lagrange multiplier. Once close to convergence the Lagrange multiplier will be close to 0, otherwise it will go towards infinity. @@ -196,10 +210,12 @@ def adjust_weight(self, lagrange: Any, self._weight = self._weight_min + self._weight_delta / (exp_lag_log + 1) def coefficients(self) -> NDArray: - r""" Calculate coefficients and adjust weights according to a Lagrange multiplier """ + r"""Calculate coefficients and adjust weights according to a Lagrange multiplier""" c, lagrange = self.solve_lagrange() self.adjust_weight(lagrange) return c -AdaptivePulayMixer = set_module("sisl.mixing")(type("AdaptivePulayMixer", (AdaptiveDIISMixer, ), {})) +AdaptivePulayMixer = set_module("sisl.mixing")( + type("AdaptivePulayMixer", (AdaptiveDIISMixer,), {}) +) diff --git a/src/sisl/mixing/linear.py b/src/sisl/mixing/linear.py index 192b6d9e9f..d4b0be508b 100644 --- a/src/sisl/mixing/linear.py +++ b/src/sisl/mixing/linear.py @@ -15,7 +15,7 @@ @set_module("sisl.mixing") class LinearMixer(BaseHistoryWeightMixer): - r""" Linear mixing + r"""Linear mixing The linear mixing is solely defined using a weight, and the resulting functional may then be calculated via: @@ -31,8 +31,8 @@ class LinearMixer(BaseHistoryWeightMixer): """ __slots__ = () - def __call__(self, f: T, df: T, append: bool=True) -> T: - r""" Calculate a new variable :math:`f'` using input and output of the functional + def __call__(self, f: T, df: T, append: bool = True) -> T: + r"""Calculate a new variable :math:`f'` using input and output of the functional Parameters ---------- @@ -92,10 +92,10 @@ def metric(a, b): return beta - def __call__(self, f: T, df: T, - delta: Optional[Any]=None, - append: bool=True) -> T: - r""" Calculate a new variable :math:`f'` using input and output of the functional + def __call__( + self, f: T, df: T, delta: Optional[Any] = None, append: bool = True + ) -> T: + r"""Calculate a new variable :math:`f'` using input and output of the functional Parameters ---------- diff --git a/src/sisl/mixing/tests/test_linear.py b/src/sisl/mixing/tests/test_linear.py index a9ae3cd770..771a873b07 100644 --- a/src/sisl/mixing/tests/test_linear.py +++ b/src/sisl/mixing/tests/test_linear.py @@ -63,10 +63,7 @@ def scf(f): f = mix(f, df) -@pytest.mark.parametrize("op", - [op.add, op.sub, - op.mul, op.truediv, - op.pow]) +@pytest.mark.parametrize("op", [op.add, op.sub, op.mul, op.truediv, op.pow]) def test_composite_mixer_init(op): mix1 = AndersonMixer() mix2 = LinearMixer() diff --git a/src/sisl/nodes/context.py b/src/sisl/nodes/context.py index 4579930fa7..bac0a2f4c6 100644 --- a/src/sisl/nodes/context.py +++ b/src/sisl/nodes/context.py @@ -10,27 +10,28 @@ lazy_init=None, # The level of logs stored in the node. log_level="INFO", - # Whether to raise a custom error exception (e.g. NodeCalcError) By default + # Whether to raise a custom error exception (e.g. NodeCalcError) By default # it is turned off because it can obscure the real problem by not showing it # in the last traceback frame. raise_custom_errors=False, ) # Temporal contexts stack. It should not be used directly by users, the aim of this -# stack is to populate it when context managers are used. This is a chainmap and +# stack is to populate it when context managers are used. This is a chainmap and # not a simple dict because we might have nested context managers. _TEMPORAL_CONTEXTS = ChainMap() + class NodeContext(ChainMap): """Extension of Chainmap that always checks on the temporal context first. - + Using this class is equivalent to forcing users to have the temporal context always in the first position of the chainmap. Since this is not a very nice thing to force on users, we use this class instead. Keys: lazy: bool - If `False`, nodes will automatically recompute if any of their inputs + If `False`, nodes will automatically recompute if any of their inputs have changed, even if no other node needs their output yet. lazy_init: bool or None Whether the node should compute on initialization. If None, defaults to @@ -47,6 +48,7 @@ def __getitem__(self, key: str): else: return super().__getitem__(key) + @contextlib.contextmanager def temporal_context(context: Union[dict, ChainMap, None] = None, **context_keys: Any): """Sets a context temporarily (until the context manager is exited). @@ -56,8 +58,8 @@ def temporal_context(context: Union[dict, ChainMap, None] = None, **context_keys context: dict or ChainMap, optional The context that should be updated temporarily. This could for example be sisl's main context or the context of a specific node class. - - If None, the keys and values are forced on all nodes. + + If None, the keys and values are forced on all nodes. **context_keys: Any The keys and values that should be used for the nodes context. @@ -88,6 +90,7 @@ def temporal_context(context: Union[dict, ChainMap, None] = None, **context_keys def _restore(): # Restore the original context. context.update(old_context) + else: # Add this temporal context on top of the temporal contexts stack. _TEMPORAL_CONTEXTS.maps.insert(0, context_keys) @@ -103,4 +106,4 @@ def _restore(): except Exception as e: # The block has raised an exception, restore the context and re-raise. _restore() - raise e \ No newline at end of file + raise e diff --git a/src/sisl/nodes/dispatcher.py b/src/sisl/nodes/dispatcher.py index 0cb7bbb65c..3de8c8a77e 100644 --- a/src/sisl/nodes/dispatcher.py +++ b/src/sisl/nodes/dispatcher.py @@ -3,7 +3,6 @@ class Dispatcher(Node): - _dispatchs = {} _node = None _dispatch_input = "key" @@ -13,9 +12,8 @@ def __init_subclass__(cls): cls._node = None cls._default_dispatch = None return super().__init_subclass__() - - def _get(self, *args, **kwargs): + def _get(self, *args, **kwargs): key = kwargs.pop(self._dispatch_input, None) if key is None: @@ -25,13 +23,15 @@ def _get(self, *args, **kwargs): kls = key else: if key not in self._dispatchs: - raise ValueError(f"Registered nodes have keys: {list(self._dispatchs)}, but {key} was requested") - + raise ValueError( + f"Registered nodes have keys: {list(self._dispatchs)}, but {key} was requested" + ) + kls = self._dispatchs[key] - + with lazy_context(nodes=True): self._node = kls(*args, **kwargs) - + return self._node.get() def __getattr__(self, key): @@ -39,10 +39,10 @@ def __getattr__(self, key): if self._node is None: self.get() return getattr(self._node, key) - + @classmethod def register(cls, key, node_cls, default=False): if default or len(cls._dispatchs) == 0: cls._default_dispatch = key - cls._dispatchs[key] = node_cls \ No newline at end of file + cls._dispatchs[key] = node_cls diff --git a/src/sisl/nodes/node.py b/src/sisl/nodes/node.py index 511887b51c..d3c3a031ea 100644 --- a/src/sisl/nodes/node.py +++ b/src/sisl/nodes/node.py @@ -17,29 +17,29 @@ class NodeError(SislError): def __init__(self, node, error): self._node = node self._error = error - + def __str__(self): return f"There was an error with node {self._node}. {self._error}" class NodeCalcError(NodeError): - def __init__(self, node, error, inputs): super().__init__(node, error) self._inputs = inputs - + def __str__(self): - return (f"Couldn't generate an output for {self._node} with the current inputs.") + return f"Couldn't generate an output for {self._node} with the current inputs." + class NodeInputError(NodeError): - def __init__(self, node, error, inputs): super().__init__(node, error) self._inputs = inputs - + def __str__(self): # Should make this more specific - return (f"Some input is not right in {self._node} and could not be parsed") + return f"Some input is not right in {self._node} and could not be parsed" + class Node(NDArrayOperatorsMixin): """Generic class for nodes. @@ -47,11 +47,12 @@ class Node(NDArrayOperatorsMixin): A node is a process that runs with some inputs and returns some outputs. Inputs can come from other nodes' outputs, and therefore outputs can be linked to another node's inputs. - + A node MUST be a pure function. That is, the output should only depend on the input. In that way, the output of the node only needs to be calculated when the inputs change. """ + # Object that will be the reference for output that has not been returned. _blank = object() # This is the signal to remove a kwarg from the inputs. @@ -77,7 +78,7 @@ class Node(NDArrayOperatorsMixin): # Current output value of the node _output: Any = _blank - + # Nodes that are connected to this node's inputs _input_nodes: Dict[str, Node] # Nodes to which the output of this node is connected @@ -101,12 +102,11 @@ class Node(NDArrayOperatorsMixin): function: Callable def __init__(self, *args, **kwargs): - self.setup(*args, **kwargs) - lazy_init = self.context['lazy_init'] + lazy_init = self.context["lazy_init"] if lazy_init is None: - lazy_init = self.context['lazy'] + lazy_init = self.context["lazy"] if not lazy_init: self.get() @@ -114,7 +114,7 @@ def __init__(self, *args, **kwargs): def __call__(self, *args, **kwargs): self.update_inputs(*args, **kwargs) return self.get() - + def setup(self, *args, **kwargs): """Sets up the node based on its initial inputs.""" # Parse inputs into arguments. @@ -137,14 +137,14 @@ def setup(self, *args, **kwargs): self._errored = False self._error = None - self._logger = logging.getLogger( - str(id(self)) + self._logger = logging.getLogger(str(id(self))) + self._log_formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)-8s :: %(message)s" ) - self._log_formatter = logging.Formatter(fmt='%(asctime)s | %(levelname)-8s :: %(message)s') self.logs = "" self.context = self.__class__.context.new_child({}) - + def __init_subclass__(cls): # Assign a context to this node class. This is a chainmap that will # resolve keys from its parents, in the order defined by the MRO, in @@ -165,25 +165,31 @@ def __init_subclass__(cls): # If the class doesn't contain a "function" attribute, it means that it is just meant # to be a base class. If it does contain a "function" attribute, it is an actual usable # node class that implements some computation. In that case, we modify the signature of the - # class to mimic the signature of the function. + # class to mimic the signature of the function. if hasattr(cls, "function"): node_func = cls.function # Get the signature of the function sig = inspect.signature(node_func) - + cls.__doc__ = node_func.__doc__ # Use the function's signature for the __init__ function, so that the help message # is actually useful. init_sig = sig if "self" not in init_sig.parameters: - init_sig = sig.replace(parameters=[ - inspect.Parameter("self", kind=inspect.Parameter.POSITIONAL_ONLY), - *sig.parameters.values() - ]) - - no_self_sig = init_sig.replace(parameters=tuple(init_sig.parameters.values())[1:]) + init_sig = sig.replace( + parameters=[ + inspect.Parameter( + "self", kind=inspect.Parameter.POSITIONAL_ONLY + ), + *sig.parameters.values(), + ] + ) + + no_self_sig = init_sig.replace( + parameters=tuple(init_sig.parameters.values())[1:] + ) # Find out if there are arguments that are VAR_POSITIONAL (*args) or VAR_KEYWORD (**kwargs) # and register it so that they can be handled on init. @@ -198,17 +204,19 @@ def __init_subclass__(cls): cls.__signature__ = no_self_sig return super().__init_subclass__() - + @classmethod - def from_func(cls, func: Union[Callable, None] = None, context: Union[dict, None] = None): + def from_func( + cls, func: Union[Callable, None] = None, context: Union[dict, None] = None + ): """Builds a node from a function. Parameters ---------- func: function, optional - The function to be converted to a node. - - If not provided, the return of this method is just a lambda function that expects + The function to be converted to a node. + + If not provided, the return of this method is just a lambda function that expects the function. This is useful if you want to use this method as a decorator while also providing extra arguments (like the context argument). context: dict, optional @@ -227,13 +235,17 @@ def from_func(cls, func: Union[Callable, None] = None, context: Union[dict, None return CallableNode(func=node) if func in cls._known_function_nodes: - return cls._known_function_nodes[func] - - new_node_cls = type(func.__name__, (cls, ), { - "function": staticmethod(func), - "_cls_context": context, - "_from_function": True - }) + return cls._known_function_nodes[func] + + new_node_cls = type( + func.__name__, + (cls,), + { + "function": staticmethod(func), + "_cls_context": context, + "_from_function": True, + }, + ) cls._known_function_nodes[func] = new_node_cls @@ -259,10 +271,10 @@ def _is_equal(prev, curr): try: if prev == curr: return True - return False + return False except: return False - + # Otherwise, check if the inputs remain the same. for key in self._prev_evaluated_inputs: # Get the previous and current values @@ -274,7 +286,13 @@ def _is_equal(prev, curr): return False - def map_inputs(self, inputs: Dict[str, Any], func: Callable, only_nodes: bool = False, exclude: Sequence[str] = ()) -> Dict[str, Any]: + def map_inputs( + self, + inputs: Dict[str, Any], + func: Callable, + only_nodes: bool = False, + exclude: Sequence[str] = (), + ) -> Dict[str, Any]: """Maps all inputs of the node applying a given function. It considers the args and kwargs keys. @@ -299,7 +317,7 @@ def map_inputs(self, inputs: Dict[str, Any], func: Callable, only_nodes: bool = if key in exclude: mapped[key] = input_val continue - + # For the special args_inputs key (if any), we need to loop through all the items if key == self._args_inputs_key: input_val = tuple( @@ -307,7 +325,8 @@ def map_inputs(self, inputs: Dict[str, Any], func: Callable, only_nodes: bool = for val in input_val ) elif key == self._kwargs_inputs_key: - input_val = {k: func(val) if not only_nodes or isinstance(val, Node) else val + input_val = { + k: func(val) if not only_nodes or isinstance(val, Node) else val for k, val in input_val.items() } else: @@ -316,12 +335,14 @@ def map_inputs(self, inputs: Dict[str, Any], func: Callable, only_nodes: bool = input_val = func(input_val) mapped[key] = input_val - + return mapped - def _sanitize_inputs(self, inputs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + def _sanitize_inputs( + self, inputs: Dict[str, Any] + ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """Converts a dictionary that may contain args and kwargs keys to a tuple of args and a dictionary of kwargs. - + Parameters ---------- inputs : Dict[str, Any] @@ -336,7 +357,7 @@ def _sanitize_inputs(self, inputs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dic kwargs.update(kwargs_inputs) return args, kwargs - + @staticmethod def evaluate_input_node(node: Node): return node.get() @@ -344,7 +365,7 @@ def evaluate_input_node(node: Node): def get(self): # Map all inputs to their values. That is, if they are nodes, call the get # method on them so that we get the updated output. This recursively evaluates nodes. - self._logger.setLevel(getattr(logging, self.context['log_level'].upper())) + self._logger.setLevel(getattr(logging, self.context["log_level"].upper())) logs = logging.StreamHandler(StringIO()) self._logger.addHandler(logs) @@ -355,7 +376,7 @@ def get(self): self._logger.debug(f"Raw inputs: {self._inputs}") evaluated_inputs = self.map_inputs( - inputs=self._inputs, + inputs=self._inputs, func=self.evaluate_input_node, only_nodes=True, ) @@ -374,12 +395,12 @@ def get(self): logs.close() self._errored = True self._error = NodeCalcError(self, e, evaluated_inputs) - - if self.context['raise_custom_errors']: + + if self.context["raise_custom_errors"]: raise self._error else: raise e - + self._nupdates += 1 self._prev_evaluated_inputs = evaluated_inputs self._outdated = False @@ -391,17 +412,18 @@ def get(self): self._logger.debug(f"Output: {self._output}.") self.logs += logs.stream.getvalue() - logs.close() + logs.close() return self._output def get_tree(self): tree = { - 'node': self, + "node": self, } - tree['inputs'] = self.map_inputs( - self._inputs, only_nodes=True, + tree["inputs"] = self.map_inputs( + self._inputs, + only_nodes=True, func=lambda node: node.get_tree(), ) @@ -419,10 +441,12 @@ def inputs(self): def get_input(self, key: str): input_val = self.inputs[key] - + return input_val - - def recursive_update_inputs(self, cls: Optional[Union[Type, Tuple[Type, ...]]] = None, **inputs): + + def recursive_update_inputs( + self, cls: Optional[Union[Type, Tuple[Type, ...]]] = None, **inputs + ): """Updates the inputs of the node recursively. This method updates the inputs of the node and all its children. @@ -435,9 +459,8 @@ def recursive_update_inputs(self, cls: Optional[Union[Type, Tuple[Type, ...]]] = The inputs to update. """ from .utils import traverse_tree_backward - - def _update(node): + def _update(node): if cls is None or isinstance(self, cls): node.update_inputs(**inputs) @@ -468,13 +491,13 @@ def update_inputs(self, **inputs): name their variadic arguments ``args``. If the function signature is ``(a: int, *arguments)`` then the key that you need to use is `arguments`. - Similarly, the **kwargs can be passed either as a dictionary in the key ``kwargs`` + Similarly, the **kwargs can be passed either as a dictionary in the key ``kwargs`` (or whatever the name of the variadic keyword arguments is). This indicates that the whole kwargs is to be replaced by the new value. Alternatively, you can pass the kwargs as separate key-value arguments, which means that you want to update the kwargs dictionary, but keep the old values. In this second option, you can indicate that a key should be removed by passing ``Node.DELETE_KWARG`` as the value. - + Parameters ---------- **inputs : @@ -483,7 +506,7 @@ def update_inputs(self, **inputs): # If no new inputs were provided, there's nothing to do if not inputs: return - + # Pop the args key (if any) so that we can parse the inputs without errors. args = None if self._args_inputs_key: @@ -500,7 +523,7 @@ def update_inputs(self, **inputs): # Now that we have parsed the inputs, put back the args key (if any). if args is not None: inputs[self._args_inputs_key] = args - + if explicit_kwargs is not None: # If a kwargs dictionary has been passed, this means that the user wants to replace # the whole kwargs dictionary. So, we just update the inputs with the new kwargs. @@ -532,27 +555,28 @@ def update_inputs(self, **inputs): def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if "out" in kwargs: - raise NotImplementedError(f"{self.__class__.__name__} does not allow the 'out' argument in ufuncs.") - inputs = {f'input_{i}': input for i, input in enumerate(inputs)} + raise NotImplementedError( + f"{self.__class__.__name__} does not allow the 'out' argument in ufuncs." + ) + inputs = {f"input_{i}": input for i, input in enumerate(inputs)} return UfuncNode(ufunc=ufunc, method=method, input_kwargs=kwargs, **inputs) def __getitem__(self, key): return GetItemNode(obj=self, key=key) - + def __getattr__(self, key): - if key.startswith('_'): + if key.startswith("_"): raise super().__getattr__(key) return GetAttrNode(obj=self, key=key) def _update_connections(self, inputs): - def _update(key, value): # Get the old connected node (if any) and tell them # that we are no longer using their input old_connection = self._input_nodes.get(key, None) if old_connection is value: # The input value has not been updated, no need to update any connections - return + return if old_connection is not None: self._input_nodes.pop(key) @@ -562,7 +586,7 @@ def _update(key, value): if isinstance(value, Node): self._input_nodes[key] = value value._receive_output_link(self) - + previous_connections = list(self._input_nodes) for key, input in inputs.items(): @@ -572,24 +596,24 @@ def _update(key, value): if not isinstance(input, DummyInputValue): input_len = len(input) for i, item in enumerate(input): - _update(f'{key}[{i}]', item) + _update(f"{key}[{i}]", item) # For indices higher than the current *args length, remove the connections. # (this is because the previous *args might have been longer) for k in previous_connections: - if k.startswith(f'{key}['): - if int(k[len(key)+1:-1]) > input_len: + if k.startswith(f"{key}["): + if int(k[len(key) + 1 : -1]) > input_len: _update(k, None) elif key == self._kwargs_inputs_key: current_kwargs = [] # Loop through all the current **kwargs to update connections if not isinstance(input, DummyInputValue): for k, item in input.items(): - connection_key = f'{key}[{k}]' + connection_key = f"{key}[{k}]" current_kwargs.append(connection_key) _update(connection_key, item) # Remove connections for those keys that are no longer in the kwargs for k in previous_connections: - if k.startswith(f'{key}[') and k not in current_kwargs: + if k.startswith(f"{key}[") and k not in current_kwargs: _update(k, None) else: # This is the normal case, where the key is not either the *args or the **kwargs key. @@ -601,22 +625,22 @@ def _receive_output_link(self, node): break else: self._output_links.append(node) - + def _receive_output_unlink(self, node): for i, linked_node in enumerate(self._output_links): if linked_node is node: del self._output_links[i] break - + def _inform_outdated(self): """Informs nodes that are linked to our output that they are outdated. - + This is either because we are outdated, or because an input update has triggered an automatic recalculation. """ for linked_node in self._output_links: linked_node._receive_outdated() - + def _receive_outdated(self): # Mark the node as outdated self._outdated = True @@ -629,48 +653,50 @@ def _receive_outdated(self): def _maybe_autoupdate(self): """Makes this node recalculate its output if automatic recalculation is turned on""" - if not self.context['lazy']: + if not self.context["lazy"]: self.get() + class DummyInputValue(Node): """A dummy node that can be used as a placeholder for input values.""" @property def input_key(self): - return self._inputs['input_key'] - + return self._inputs["input_key"] + @property def value(self): - return self._inputs.get('value', Node._blank) - + return self._inputs.get("value", Node._blank) + @staticmethod def function(input_key: str, value: Any = Node._blank): return value -class FuncNode(Node): +class FuncNode(Node): @staticmethod def function(*args, func: Callable, **kwargs): return func(*args, **kwargs) - + + class CallableNode(FuncNode): - def __call__(self, *args, **kwargs): self.update_inputs(*args, **kwargs) return self -class GetItemNode(Node): +class GetItemNode(Node): @staticmethod def function(obj: Any, key: Any): return obj[key] - -class GetAttrNode(Node): + +class GetAttrNode(Node): @staticmethod def function(obj: Any, key: str): return getattr(obj, key) + class UfuncNode(Node): """Node that wraps a numpy ufunc.""" @@ -680,7 +706,7 @@ def __call__(self, *args, **kwargs): @staticmethod def function(ufunc, method: str, input_kwargs: Dict[str, Any], **kwargs): - # We need to + # We need to inputs = [] i = 0 while True: @@ -689,8 +715,9 @@ def function(ufunc, method: str, input_kwargs: Dict[str, Any], **kwargs): break inputs.append(kwargs.pop(key)) i += 1 - return getattr(ufunc, method)(*inputs, **input_kwargs) - + return getattr(ufunc, method)(*inputs, **input_kwargs) + + class ConstantNode(Node): """Node that just returns its input value.""" diff --git a/src/sisl/nodes/syntax_nodes.py b/src/sisl/nodes/syntax_nodes.py index bcf4719a9d..7e5e347e69 100644 --- a/src/sisl/nodes/syntax_nodes.py +++ b/src/sisl/nodes/syntax_nodes.py @@ -4,21 +4,20 @@ class SyntaxNode(Node): ... + class ListSyntaxNode(SyntaxNode): - @staticmethod def function(*items): return list(items) - + class TupleSyntaxNode(SyntaxNode): - @staticmethod def function(*items): return tuple(items) - + + class DictSyntaxNode(SyntaxNode): - @staticmethod def function(**items): - return items \ No newline at end of file + return items diff --git a/src/sisl/nodes/tests/test_context.py b/src/sisl/nodes/tests/test_context.py index 79d06d7345..a9ee16f097 100644 --- a/src/sisl/nodes/tests/test_context.py +++ b/src/sisl/nodes/tests/test_context.py @@ -5,7 +5,6 @@ def test_node(): - @Node.from_func def sum_node(a, b): return a + b @@ -19,11 +18,13 @@ def sum_node(a, b): val = sum_node(a=2, b=3) assert val == 5 + def test_node_inside_node(): """When a node class is called inside another node, it should never be lazy in its computation. - + That is, calling a node within another node is like calling a function. """ + @Node.from_func def shift(a): return a + 1 @@ -42,9 +43,9 @@ def sum_node(a, b): val = sum_node(a=2, b=3) assert val == 6 + @pytest.mark.parametrize("nodes_lazy", [True, False]) def test_workflow(nodes_lazy): - def sum_node(a, b): return a + b @@ -52,26 +53,26 @@ def sum_node(a, b): def my_workflow(a, b, c): first_sum = sum_node(a, b) return sum_node(first_sum, c) - + with temporal_context(context=Node.context, lazy=nodes_lazy): - #It shouldn't matter whether nodes have lazy computation on or off for the working of the workflow + # It shouldn't matter whether nodes have lazy computation on or off for the working of the workflow with temporal_context(context=Workflow.context, lazy=True): val = my_workflow(a=2, b=3, c=4) assert isinstance(val, my_workflow) assert val.get() == 9 - + with temporal_context(context=Workflow.context, lazy=False): val = my_workflow(a=2, b=3, c=4) assert val == 9 -def test_instance_context(): +def test_instance_context(): @Node.from_func def sum_node(a, b): return a + b - + sum_node.context.update(lazy=True) - + # By default, an instance should behave as the class context specifies, # so in this case the node should not automatically recalculate val = sum_node(a=2, b=3) @@ -103,7 +104,7 @@ def test_default_context(lazy_init): @Node.from_func def calc(val: int): return val - + @Node.from_func(context={"lazy": False, "lazy_init": lazy_init}) def alert_change(val: int): ... @@ -111,11 +112,11 @@ def alert_change(val: int): val = calc(1) init_nupdates = 0 if lazy_init else 1 - - # We feed the node that produces the intermediate value into our alert node + + # We feed the node that produces the intermediate value into our alert node my_alert = alert_change(val=val) val.get() assert my_alert._nupdates == init_nupdates val.update_inputs(val=2) - assert my_alert._nupdates == init_nupdates + 1 \ No newline at end of file + assert my_alert._nupdates == init_nupdates + 1 diff --git a/src/sisl/nodes/tests/test_node.py b/src/sisl/nodes/tests/test_node.py index a45bf3680c..c01262c107 100644 --- a/src/sisl/nodes/tests/test_node.py +++ b/src/sisl/nodes/tests/test_node.py @@ -4,20 +4,24 @@ from sisl.nodes.node import GetItemNode -@pytest.fixture(scope='module', params=["explicit_class", "from_func"]) +@pytest.fixture(scope="module", params=["explicit_class", "from_func"]) def sum_node(request): if request.param == "explicit_class": + class SumNode(Node): @staticmethod def function(input1, input2): return input1 + input2 + else: + @Node.from_func def SumNode(input1, input2): return input1 + input2 return SumNode + def test_node_classes_reused(): def a(): pass @@ -27,22 +31,25 @@ def a(): assert x is y + def test_node_runs(sum_node): - node = sum_node(1,2) + node = sum_node(1, 2) res = node.get() assert res == 3 + @temporal_context(lazy=False) def test_nonlazy_node(sum_node): - node = sum_node(1,2) - + node = sum_node(1, 2) + assert node._nupdates == 1 assert node._output == 3 + @temporal_context(lazy=True) def test_node_not_updated(sum_node): """Checks that the node only runs when it needs to.""" - node = sum_node(1,2) + node = sum_node(1, 2) assert node._nupdates == 0 @@ -54,9 +61,9 @@ def test_node_not_updated(sum_node): assert res == 3 assert node._nupdates == 1 + @temporal_context(lazy=True) def test_node_links(): - @Node.from_func def my_node(a: int = 2): return a @@ -71,8 +78,8 @@ def my_node(a: int = 2): # And that node2 knows it's using node1 as an input. assert len(node2._input_nodes) == 1 - assert 'a' in node2._input_nodes - assert node2._input_nodes['a'] is node1 + assert "a" in node2._input_nodes + assert node2._input_nodes["a"] is node1 # Now check that if we update node2, the connections # will be removed. @@ -93,12 +100,13 @@ def my_node(a: int = 2): # And that node2 knows it's using node3 as an input. assert len(node2._input_nodes) == 1 - assert 'a' in node2._input_nodes - assert node2._input_nodes['a'] is node3 + assert "a" in node2._input_nodes + assert node2._input_nodes["a"] is node3 + @temporal_context(lazy=True) def test_node_tree(sum_node): - node1 = sum_node(1,2) + node1 = sum_node(1, 2) node2 = sum_node(node1, 3) res = node2.get() @@ -117,8 +125,8 @@ def test_node_tree(sum_node): assert res == 15 assert node1._nupdates == 2 -def test_automatic_recalculation(sum_node): +def test_automatic_recalculation(sum_node): # Set the first node automatic recalculation on node1 = sum_node(1, 2) assert node1._nupdates == 0 @@ -152,29 +160,29 @@ def test_automatic_recalculation(sum_node): assert node1._output == 5 assert node2._output == 8 -def test_getitem(): +def test_getitem(): @Node.from_func def some_tuple(): return (3, 4) - + my_tuple = some_tuple() val = my_tuple.get()[0] assert val == 3 item = my_tuple[0] - + assert isinstance(item, GetItemNode) assert item.get() == 3 + @temporal_context(lazy=True) def test_args(): """Checks that functions with *args are correctly handled by Node.""" @Node.from_func def reduce_(*nums, factor: int = 1): - val = 0 for num in nums: val += num @@ -189,9 +197,9 @@ def reduce_(*nums, factor: int = 1): assert val2.get() == 16 assert val._nupdates == 1 + @temporal_context(lazy=True) def test_node_links_args(): - @Node.from_func def my_node(*some_args): return some_args @@ -205,8 +213,8 @@ def my_node(*some_args): # And that node2 knows it's using node1 as an input. assert len(node2._input_nodes) == 1 - assert 'some_args[2]' in node2._input_nodes - assert node2._input_nodes['some_args[2]'] is node1 + assert "some_args[2]" in node2._input_nodes + assert node2._input_nodes["some_args[2]"] is node1 @temporal_context(lazy=True) @@ -225,6 +233,7 @@ def my_dict(**some_kwargs): assert val2.get() == {"old": {"a": 2, "b": 4}} + @temporal_context(lazy=True) def test_update_kwargs(): """Checks that functions with **kwargs are correctly handled by Node.""" @@ -245,9 +254,9 @@ def my_dict(**some_kwargs): val.update_inputs(a=Node.DELETE_KWARG) assert val.get() == {"b": 4, "c": 5} + @temporal_context(lazy=True) def test_node_links_kwargs(): - @Node.from_func def my_node(**some_kwargs): return some_kwargs @@ -262,8 +271,8 @@ def my_node(**some_kwargs): # And that node2 knows it's using node1 as an input. assert len(node2._input_nodes) == 1 - assert 'some_kwargs[a]' in node2._input_nodes - assert node2._input_nodes['some_kwargs[a]'] is node1 + assert "some_kwargs[a]" in node2._input_nodes + assert node2._input_nodes["some_kwargs[a]"] is node1 # Test that kwargs that no longer exist are delinked. @@ -286,15 +295,15 @@ def my_node(**some_kwargs): # And that node2 knows it's using node3 as an input. assert len(node2._input_nodes) == 1 - assert 'some_kwargs[a]' in node2._input_nodes - assert node2._input_nodes['some_kwargs[a]'] is node3 + assert "some_kwargs[a]" in node2._input_nodes + assert node2._input_nodes["some_kwargs[a]"] is node3 -def test_ufunc(sum_node): +def test_ufunc(sum_node): node = sum_node(1, 3) assert node.get() == 4 node2 = node + 6 - assert node2.get() == 10 \ No newline at end of file + assert node2.get() == 10 diff --git a/src/sisl/nodes/tests/test_syntax_nodes.py b/src/sisl/nodes/tests/test_syntax_nodes.py index 6d607bc62c..6267491f8f 100644 --- a/src/sisl/nodes/tests/test_syntax_nodes.py +++ b/src/sisl/nodes/tests/test_syntax_nodes.py @@ -5,25 +5,27 @@ def test_list_syntax_node(): assert ListSyntaxNode("a", "b", "c").get() == ["a", "b", "c"] + def test_tuple_syntax_node(): assert TupleSyntaxNode("a", "b", "c").get() == ("a", "b", "c") + def test_dict_syntax_node(): assert DictSyntaxNode(a="b", c="d", e="f").get() == {"a": "b", "c": "d", "e": "f"} -def test_workflow_with_syntax(): +def test_workflow_with_syntax(): def f(a): return [a] - + assert Workflow.from_func(f)(2).get() == [2] def f(a): return (a,) - + assert Workflow.from_func(f)(2).get() == (2,) def f(a): return {"a": a} - + assert Workflow.from_func(f)(2).get() == {"a": 2} diff --git a/src/sisl/nodes/tests/test_utils.py b/src/sisl/nodes/tests/test_utils.py index 0e12873727..9833702496 100644 --- a/src/sisl/nodes/tests/test_utils.py +++ b/src/sisl/nodes/tests/test_utils.py @@ -9,104 +9,106 @@ ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def sum_node(): - @Node.from_func def sum(a, b): return a + b - + return sum - -def test_traverse_tree_forward(sum_node): + +def test_traverse_tree_forward(sum_node): initial = sum_node(0, 1) second = sum_node(initial, 2) final = sum_node(second, 3) i = 0 + def count(node): nonlocal i i += 1 - traverse_tree_forward((final, ), func=count) + traverse_tree_forward((final,), func=count) assert i == 1 i = 0 - traverse_tree_forward((second, ), func=count) + traverse_tree_forward((second,), func=count) assert i == 2 i = 0 - traverse_tree_forward((initial, ), func=count) + traverse_tree_forward((initial,), func=count) assert i == 3 def only_first(node): nonlocal i i += 1 raise StopTraverse - + i = 0 - traverse_tree_forward((initial, ), func=only_first) + traverse_tree_forward((initial,), func=only_first) assert i == 1 -def test_traverse_tree_backward(sum_node): +def test_traverse_tree_backward(sum_node): initial = sum_node(0, 1) second = sum_node(initial, 2) final = sum_node(second, 3) i = 0 + def count(node): nonlocal i i += 1 - traverse_tree_backward((final, ), func=count) + traverse_tree_backward((final,), func=count) assert i == 3 i = 0 - traverse_tree_backward((second, ), func=count) + traverse_tree_backward((second,), func=count) assert i == 2 i = 0 - traverse_tree_backward((initial, ), func=count) + traverse_tree_backward((initial,), func=count) assert i == 1 def only_first(node): nonlocal i i += 1 raise StopTraverse - + i = 0 - traverse_tree_backward((final, ), func=only_first) + traverse_tree_backward((final,), func=only_first) assert i == 1 -def test_visit_all_connected(sum_node): +def test_visit_all_connected(sum_node): initial = sum_node(0, 1) second = sum_node(initial, 2) final = sum_node(second, 3) i = 0 + def count(node): nonlocal i i += 1 - visit_all_connected((initial, ), func=count) + visit_all_connected((initial,), func=count) assert i == 3 i = 0 - visit_all_connected((second, ), func=count) + visit_all_connected((second,), func=count) assert i == 3 i = 0 - visit_all_connected((final, ), func=count) + visit_all_connected((final,), func=count) assert i == 3 def only_first(node): nonlocal i i += 1 raise StopTraverse - + i = 0 - visit_all_connected((final, ), func=only_first) - assert i == 1 \ No newline at end of file + visit_all_connected((final,), func=only_first) + assert i == 1 diff --git a/src/sisl/nodes/tests/test_workflow.py b/src/sisl/nodes/tests/test_workflow.py index 1b05a1ad1b..75c4fa5418 100644 --- a/src/sisl/nodes/tests/test_workflow.py +++ b/src/sisl/nodes/tests/test_workflow.py @@ -6,35 +6,38 @@ from sisl.nodes.utils import traverse_tree_forward -@pytest.fixture(scope='module', params=["from_func", "explicit_class", "input_operations"]) +@pytest.fixture( + scope="module", params=["from_func", "explicit_class", "input_operations"] +) def triple_sum(request) -> Type[Workflow]: """Returns a workflow that computes a triple sum. - + The workflow might have been obtained in different ways, but they all should be equivalent in functionality. """ def my_sum(a, b): return a + b - + if request.param == "from_func": - # A triple sum + # A triple sum @Workflow.from_func def triple_sum(a, b, c): first_sum = my_sum(a, b) return my_sum(first_sum, c) - + triple_sum._sum_key = "my_sum" elif request.param == "explicit_class": - class triple_sum(Workflow): + class triple_sum(Workflow): @staticmethod def function(a, b, c): first_sum = my_sum(a, b) return my_sum(first_sum, c) - + triple_sum._sum_key = "my_sum" elif request.param == "input_operations": + @Workflow.from_func def triple_sum(a, b, c): first_sum = a + b @@ -44,19 +47,24 @@ def triple_sum(a, b, c): return triple_sum + def test_named_vars(triple_sum): # Check that the first_sum variable has been detected correctly. - assert set(triple_sum.dryrun_nodes.named_vars) == {'first_sum'} + assert set(triple_sum.dryrun_nodes.named_vars) == {"first_sum"} # And check that it maps to the correct node. - assert triple_sum.dryrun_nodes.first_sum is triple_sum.dryrun_nodes.workers[triple_sum._sum_key] + assert ( + triple_sum.dryrun_nodes.first_sum + is triple_sum.dryrun_nodes.workers[triple_sum._sum_key] + ) + def test_workflow_instantiation(triple_sum): # Create an instance of the workflow. flow = triple_sum(2, 3, 5) # Check that the workflow nodes have been instantiated. - assert hasattr(flow, 'nodes') + assert hasattr(flow, "nodes") for k, wf_node in flow.dryrun_nodes.items(): assert k in flow.nodes._all_nodes new_node = flow.nodes[k] @@ -71,23 +79,24 @@ def check_not_old_id(node): traverse_tree_forward(flow.nodes.inputs.values(), check_not_old_id) + def test_right_result(triple_sum): assert triple_sum(a=2, b=3, c=5).get() == 10 -def test_updatable_inputs(triple_sum): +def test_updatable_inputs(triple_sum): val = triple_sum(a=2, b=3, c=5) - + assert val.get() == 10 val.update_inputs(b=4) assert val.get() == 11 -def test_recalc_necessary_only(triple_sum): +def test_recalc_necessary_only(triple_sum): val = triple_sum(a=2, b=3, c=5) - + assert val.get() == 10 val.update_inputs(c=4) @@ -95,15 +104,16 @@ def test_recalc_necessary_only(triple_sum): assert val.get() == 9 assert val.nodes[triple_sum._sum_key]._nupdates == 1 - assert val.nodes[f'{triple_sum._sum_key}_1']._nupdates == 2 + assert val.nodes[f"{triple_sum._sum_key}_1"]._nupdates == 2 -def test_positional_arguments(triple_sum): +def test_positional_arguments(triple_sum): val = triple_sum(2, 3, 5) assert val.get() == 10 -# *args and **kwargs are not supported for now in workflows. + +# *args and **kwargs are not supported for now in workflows. # def test_kwargs_not_overriden(): # @Node.from_func @@ -113,15 +123,15 @@ def test_positional_arguments(triple_sum): # @Workflow.from_func # def some_workflow(**kwargs): # return some_node(a=2, b=3, **kwargs) - + # # Here we check that passing **kwargs to the node inside the workflow # # does not interfere with the other keyword arguments that are explicitly # # passed to the node (and accepted by the node as **kwargs) # assert some_workflow().get() == {'a': 2, 'b': 3} # assert some_workflow(c=4).get() == {'a': 2, 'b': 3, 'c': 4} -def test_args_nodes_registered(): +def test_args_nodes_registered(): def some_node(*args): return args @@ -129,13 +139,13 @@ def some_node(*args): def some_workflow(): a = some_node(1, 2, 3) return some_node(2, a, 4) - + # Check that the workflow knows about the first instanced node. wf = some_workflow() assert len(wf.nodes.workers) == 2 -def test_kwargs_nodes_registered(): +def test_kwargs_nodes_registered(): def some_node(**kwargs): return kwargs @@ -143,36 +153,36 @@ def some_node(**kwargs): def some_workflow(): a = some_node(a=1, b=2, c=3) return some_node(b=2, a=a, c=4) - + # Check that the workflow knows about the first instanced node. wf = some_workflow() assert len(wf.nodes.workers) == 2 -def test_workflow_inside_workflow(triple_sum): +def test_workflow_inside_workflow(triple_sum): def multiply(a, b): return a * b @Workflow.from_func def some_multiplication(a, b, c, d, e, f): - """ Workflow that computes (a + b + c) * (d + e + f)""" - return multiply(triple_sum(a,b,c), triple_sum(d, e, f)) + """Workflow that computes (a + b + c) * (d + e + f)""" + return multiply(triple_sum(a, b, c), triple_sum(d, e, f)) val = some_multiplication(1, 2, 3, 1, 2, 1) assert val.get() == (1 + 2 + 3) * (1 + 2 + 1) - first_triple_sum = val.nodes['triple_sum'] + first_triple_sum = val.nodes["triple_sum"] assert first_triple_sum.nodes[triple_sum._sum_key]._nupdates == 1 - assert first_triple_sum.nodes[f'{triple_sum._sum_key}_1']._nupdates == 1 + assert first_triple_sum.nodes[f"{triple_sum._sum_key}_1"]._nupdates == 1 val.update_inputs(c=2) assert first_triple_sum.nodes[triple_sum._sum_key]._nupdates == 1 - assert first_triple_sum.nodes[f'{triple_sum._sum_key}_1']._nupdates == 1 + assert first_triple_sum.nodes[f"{triple_sum._sum_key}_1"]._nupdates == 1 assert val.get() == (1 + 2 + 2) * (1 + 2 + 1) assert first_triple_sum.nodes[triple_sum._sum_key]._nupdates == 1 - assert first_triple_sum.nodes[f'{triple_sum._sum_key}_1']._nupdates == 2 \ No newline at end of file + assert first_triple_sum.nodes[f"{triple_sum._sum_key}_1"]._nupdates == 2 diff --git a/src/sisl/nodes/utils.py b/src/sisl/nodes/utils.py index 27e61bfb01..f5215fcba7 100644 --- a/src/sisl/nodes/utils.py +++ b/src/sisl/nodes/utils.py @@ -5,10 +5,11 @@ from .node import Node -class StopTraverse(Exception): +class StopTraverse(Exception): """Exception that should be raised by callback functions to stop the traversal of a tree.""" -def traverse_tree_forward(roots: Sequence[Node], func: Callable[[Node], Any]) -> None: + +def traverse_tree_forward(roots: Sequence[Node], func: Callable[[Node], Any]) -> None: """Traverse a tree of nodes in a forward fashion. Parameters @@ -26,7 +27,8 @@ def traverse_tree_forward(roots: Sequence[Node], func: Callable[[Node], Any]) - continue traverse_tree_forward(root._output_links, func) -def traverse_tree_backward(leaves: Sequence[Node], func: Callable[[Node], Any]) -> None: + +def traverse_tree_backward(leaves: Sequence[Node], func: Callable[[Node], Any]) -> None: """Traverse a tree of nodes in a backwards fashion. Parameters @@ -43,12 +45,15 @@ def traverse_tree_backward(leaves: Sequence[Node], func: Callable[[Node], Any]) except StopTraverse: continue leaf.map_inputs( - leaf.inputs, - func=lambda node: traverse_tree_backward((node, ), func=func), - only_nodes=True + leaf.inputs, + func=lambda node: traverse_tree_backward((node,), func=func), + only_nodes=True, ) -def visit_all_connected(nodes: Sequence[Node], func: Callable[[Node], Any], _seen_nodes=None) -> None: + +def visit_all_connected( + nodes: Sequence[Node], func: Callable[[Node], Any], _seen_nodes=None +) -> None: """Visit all nodes that are connected to a list of nodes. Parameters @@ -67,21 +72,22 @@ def visit_all_connected(nodes: Sequence[Node], func: Callable[[Node], Any], _see continue _seen_nodes.append(id(node)) - + try: func(node) except StopTraverse: continue - + def visit(visited_node): if visited_node is node: return - - visit_all_connected((visited_node, ), func=func, _seen_nodes=_seen_nodes) + + visit_all_connected((visited_node,), func=func, _seen_nodes=_seen_nodes) raise StopTraverse - - traverse_tree_forward((node, ), func=visit ) - traverse_tree_backward((node, ), func=visit) + + traverse_tree_forward((node,), func=visit) + traverse_tree_backward((node,), func=visit) + def nodify_module(module: ModuleType, node_class: Type[Node] = Node) -> ModuleType: """Returns a copy of a module where all functions are replaced with nodes. @@ -107,13 +113,15 @@ def nodify_module(module: ModuleType, node_class: Type[Node] = Node) -> ModuleTy ModuleType A new module with all functions replaced with nodes. """ - + # Function that recursively traverses the module and replaces functions with nodes. - def _nodified_module(module: ModuleType, visited: Dict[ModuleType, ModuleType], main_module: str) -> ModuleType: + def _nodified_module( + module: ModuleType, visited: Dict[ModuleType, ModuleType], main_module: str + ) -> ModuleType: # This module has already been visited, so do return the already nodified module. if module in visited: return visited[module] - + # Create a copy of this module, with the nodified_ prefix in the name. noded_module = ModuleType(f"nodified_{module.__name__}") # Register the module as visited. @@ -134,7 +142,9 @@ def _nodified_module(module: ModuleType, visited: Dict[ModuleType, ModuleType], # skip it. This is to avoid nodifying variables that were imported # from other modules. module_name = getattr(variable, "__module__", "") or "" - if not (isinstance(module_name, str) and module_name.startswith(main_module)): + if not ( + isinstance(module_name, str) and module_name.startswith(main_module) + ): continue # If the variable is a function or a class, try to create a node from it. @@ -147,16 +157,20 @@ def _nodified_module(module: ModuleType, visited: Dict[ModuleType, ModuleType], ... elif inspect.ismodule(variable): module_name = getattr(variable, "__name__", "") or "" - if not (isinstance(module_name, str) and module_name.startswith(main_module)): + if not ( + isinstance(module_name, str) and module_name.startswith(main_module) + ): continue # If the variable is a module, recursively nodify it. - noded_variable = _nodified_module(variable, visited, main_module=main_module) - + noded_variable = _nodified_module( + variable, visited, main_module=main_module + ) + # Add the new noded variable to the new module. if noded_variable is not None: setattr(noded_module, k, noded_variable) return noded_module - - return _nodified_module(module, visited={}, main_module=module.__name__) \ No newline at end of file + + return _nodified_module(module, visited={}, main_module=module.__name__) diff --git a/src/sisl/nodes/workflow.py b/src/sisl/nodes/workflow.py index 77d7353544..7780276e27 100644 --- a/src/sisl/nodes/workflow.py +++ b/src/sisl/nodes/workflow.py @@ -18,26 +18,28 @@ from .utils import traverse_tree_backward, traverse_tree_forward register_environ_variable( - "SISL_NODES_EXPORT_VIS", default=False, - description="Whether the visualizations of the networks in notebooks are meant to be exported.", + "SISL_NODES_EXPORT_VIS", + default=False, + description="Whether the visualizations of the networks in notebooks are meant to be exported.", ) + class WorkflowInput(DummyInputValue): pass + class WorkflowOutput(Node): - @staticmethod def function(value: Any) -> Any: return value -class NetworkDescriptor: +class NetworkDescriptor: def __get__(self, instance, owner): return Network(owner) -class Network: +class Network: _workflow: Type[Workflow] def __init__(self, workflow: Type[Workflow]): @@ -45,10 +47,12 @@ def __init__(self, workflow: Type[Workflow]): @staticmethod def _get_edges( - workflow: Type[Workflow], include_workflow_inputs: bool = False, edge_labels: bool = True + workflow: Type[Workflow], + include_workflow_inputs: bool = False, + edge_labels: bool = True, ) -> List[Tuple[str, str, dict]]: """Get the edges that connect nodes in the workflow. - + Parameters ---------- workflow: Type[Workflow] @@ -69,14 +73,16 @@ def _get_edges( """ edges = [] - # Build the function that given two connected nodes will return the + # Build the function that given two connected nodes will return the # edge metadata or "props". def _edge_props(node_out, node_in, key) -> dict: props = {} if edge_labels: - props['label'] = key - - props['title'] = f"{node_out.__class__.__name__}() -> {node_in.__class__.__name__}.{key}" + props["label"] = key + + props[ + "title" + ] = f"{node_out.__class__.__name__}() -> {node_in.__class__.__name__}.{key}" return props # Get the workflow's nodes @@ -90,26 +96,33 @@ def _edge_props(node_out, node_in, key) -> dict: # Loop over the inputs of this node that contain other nodes and # add the edges that connect them to this node. for other_key, input_node in node._input_nodes.items(): - if not include_workflow_inputs and isinstance(input_node, WorkflowInput): + if not include_workflow_inputs and isinstance( + input_node, WorkflowInput + ): continue - edges.append(( - workflow.find_node_key(input_node), - node_key, - _edge_props(input_node, node, other_key) - )) - + edges.append( + ( + workflow.find_node_key(input_node), + node_key, + _edge_props(input_node, node, other_key), + ) + ) + return edges - def to_nx(self, workflow: Type[Workflow], include_workflow_inputs: bool = False, - edge_labels: bool = False + def to_nx( + self, + workflow: Type[Workflow], + include_workflow_inputs: bool = False, + edge_labels: bool = False, ) -> "nx.DiGraph": """Convert a Workflow class to a networkx directed graph. The nodes of the graph are the node functions that compose the workflow. The edges represent connections between nodes, where one node's output is sent to another node's input. - + Parameters ---------- workflow: Type[Workflow] @@ -123,30 +136,44 @@ def to_nx(self, workflow: Type[Workflow], include_workflow_inputs: bool = False, import networkx as nx # Get the edges - edges = self._get_edges(workflow, include_workflow_inputs=include_workflow_inputs, edge_labels=edge_labels) - + edges = self._get_edges( + workflow, + include_workflow_inputs=include_workflow_inputs, + edge_labels=edge_labels, + ) + # And build the networkx directed graph graph = nx.DiGraph() - graph.add_edges_from(edges, ) + graph.add_edges_from( + edges, + ) for name, node_key in workflow.dryrun_nodes.named_vars.items(): - graph.nodes[node_key]['label'] = name - + graph.nodes[node_key]["label"] = name + return graph - def to_pyvis(self, colorscale: str = "viridis", show_workflow_inputs: bool = False, - edge_labels: bool = True, node_help: bool = True, - notebook: bool = False, hierarchial: bool = True, inputs_props: Dict[str, Any] = {}, node_props: Dict[str, Any] = {}, - leafs_props: Dict[str, Any] = {}, output_props: Dict[str, Any] = {}, + def to_pyvis( + self, + colorscale: str = "viridis", + show_workflow_inputs: bool = False, + edge_labels: bool = True, + node_help: bool = True, + notebook: bool = False, + hierarchial: bool = True, + inputs_props: Dict[str, Any] = {}, + node_props: Dict[str, Any] = {}, + leafs_props: Dict[str, Any] = {}, + output_props: Dict[str, Any] = {}, auto_text_color: bool = True, to_export: Union[bool, None] = None, - ): + ): """Convert a Workflow class to a pyvis network for visualization. The nodes of the graph are the node functions that compose the workflow. The edges represent connections between nodes, where one node's output is sent to another node's input. - + Parameters ---------- colorscale: str, optional @@ -187,13 +214,19 @@ def to_pyvis(self, colorscale: str = "viridis", show_workflow_inputs: bool = Fal import networkx as nx from pyvis.network import Network as visNetwork except ModuleNotFoundError: - raise ModuleNotFoundError("You need to install the 'networkx', 'pyvis' and 'matplotlib' packages to visualize workflows.") + raise ModuleNotFoundError( + "You need to install the 'networkx', 'pyvis' and 'matplotlib' packages to visualize workflows." + ) if to_export is None: to_export = get_environ_variable("SISL_NODES_EXPORT_VIS") != False - + # Get the networkx directed graph - graph = self.to_nx(self._workflow, include_workflow_inputs=show_workflow_inputs, edge_labels=edge_labels) + graph = self.to_nx( + self._workflow, + include_workflow_inputs=show_workflow_inputs, + edge_labels=edge_labels, + ) # Find out the generations of nodes (i.e. how many nodes are there until a beginning of the graph) topo_gens = list(nx.topological_generations(graph)) @@ -201,10 +234,14 @@ def to_pyvis(self, colorscale: str = "viridis", show_workflow_inputs: bool = Fal # and whether the last generation is just the output node. # This will help us avoid these generations when coloring on generation, # allowing for more color range. - inputs_are_first_gen = all(node in self._workflow.dryrun_nodes.inputs for node in topo_gens[0]) + inputs_are_first_gen = all( + node in self._workflow.dryrun_nodes.inputs for node in topo_gens[0] + ) output_node_key = "output" - output_is_last_gen = len(topo_gens[-1]) == 1 and output_node_key == topo_gens[-1][0] + output_is_last_gen = ( + len(topo_gens[-1]) == 1 and output_node_key == topo_gens[-1][0] + ) # Then determine the range min_gen = 0 @@ -221,12 +258,18 @@ def _get_node_help(node: Union[Node, Type[Node]]): _get_node_inputs_str(node) else: node_cls = node - short_doc = (node_cls.__doc__ or "").lstrip().split('\n')[0] or "No documentation" - sig = "\n".join(str(param) for param in node_cls.__signature__.parameters.values()) + short_doc = (node_cls.__doc__ or "").lstrip().split("\n")[ + 0 + ] or "No documentation" + sig = "\n".join( + str(param) for param in node_cls.__signature__.parameters.values() + ) return f"Node class: {node_cls.__name__}\n{short_doc}\n............................................\n{sig}" def _get_node_inputs_str(node: Node) -> str: - short_doc = (node.__class__.__doc__ or "").lstrip().split('\n')[0] or "No documentation" + short_doc = (node.__class__.__doc__ or "").lstrip().split("\n")[ + 0 + ] or "No documentation" node_inputs_str = f"Node class: {node.__class__.__name__}\n{short_doc}\n............................................\n" @@ -255,11 +298,12 @@ def _render_value(v): # in matplotlib > 3.5 mpl.colormaps[colorscale] # This is portable cmap = plt.get_cmap(colorscale) + def rgb2gray(rgb): return rgb[0] * 0.2989 + rgb[1] * 0.5870 + rgb[2] * 0.1140 + # Loop through generations of nodes. for i, nodes in enumerate(topo_gens): - # Get the level, color and shape for this generation level = float(i + 1) @@ -267,18 +311,24 @@ def rgb2gray(rgb): color_value = (i - min_gen) / color_range else: color_value = 0 - + # Get the color for this generation rgb = cmap(color_value) color = mpl.colors.rgb2hex(rgb) - + shape = node_props.get("shape", "circle") # If automatic text coloring is requested, set the font color to white if the background is dark. # But only if the shape is an ellipse, circle, database, box or text, as these are the only ones # that have the text inside. font = {} - if auto_text_color and shape in ["ellipse", "circle", "database", "box", "text"]: + if auto_text_color and shape in [ + "ellipse", + "circle", + "database", + "box", + "text", + ]: gray = rgb2gray(rgb) if gray <= 0.5: font = {"color": "white"} @@ -289,28 +339,41 @@ def rgb2gray(rgb): if node_help: node_obj = self._workflow.dryrun_nodes.get(node) - title = _get_node_inputs_str(node_obj) if node_obj is not None else "" + title = ( + _get_node_inputs_str(node_obj) if node_obj is not None else "" + ) else: title = "" - graph_node.update({ - "mass": 2, "shape": shape, "color": color, "level": level, - "title": title, "font": font, - **node_props - }) + graph_node.update( + { + "mass": 2, + "shape": shape, + "color": color, + "level": level, + "title": title, + "font": font, + **node_props, + } + ) # Set the props of leaf nodes (those that have their output linked to nothing) leaves = [k for k, v in graph.out_degree if v == 0] for leaf in leaves: - graph.nodes[leaf].update({ - "shape": "square", "font": {"color": "black"}, **leafs_props - }) - + graph.nodes[leaf].update( + {"shape": "square", "font": {"color": "black"}, **leafs_props} + ) + # Set the props of the output node - graph.nodes[output_node_key].update({ - "color": "pink", "shape": "box", "label": "Output", - "font": {"color": "black"}, **output_props - }) + graph.nodes[output_node_key].update( + { + "color": "pink", + "shape": "box", + "label": "Output", + "font": {"color": "black"}, + **output_props, + } + ) if show_workflow_inputs: for k, node in self._workflow.dryrun_nodes.inputs.items(): @@ -321,25 +384,32 @@ def rgb2gray(rgb): if node._output_links: min_level = min( - graph.nodes[self._workflow.find_node_key(node)].get("level", 0) + graph.nodes[self._workflow.find_node_key(node)].get("level", 0) for node in node._output_links ) else: min_level = 2 - graph.nodes[k].update({ - "color": "#ccc", "font": {"color": "black"}, "shape": "box", - "label": f"Input({k})", "level": min_level - 1, - **inputs_props - }) - + graph.nodes[k].update( + { + "color": "#ccc", + "font": {"color": "black"}, + "shape": "box", + "label": f"Input({k})", + "level": min_level - 1, + **inputs_props, + } + ) + layout = {} if hierarchial: layout["hierarchial"] = True - net = visNetwork(notebook=notebook, directed=True, layout=layout, cdn_resources="remote") + net = visNetwork( + notebook=notebook, directed=True, layout=layout, cdn_resources="remote" + ) net.from_nx(graph) - + net.toggle_physics(True) net.toggle_stabilization(True) return net @@ -367,7 +437,7 @@ def _show_pyvis(net: "pyvis.Network", notebook: bool, to_export: Union[bool, Non if notebook: # Render the HTML in the notebook. - + from IPython.display import HTML, display # First option was to display as an HTML object @@ -375,15 +445,20 @@ def _show_pyvis(net: "pyvis.Network", notebook: bool, to_export: Union[bool, Non # The wrapper div is needed because otherwise the HTML display is of height 0. # HOWEVER: IT DOESN'T DISPLAY when exported to HTML because the iframe that isolated=True # creates is removed. - #obj = HTML(f'
{html}
', metadata={"isolated": True}, ) + # obj = HTML(f'
{html}
', metadata={"isolated": True}, ) # Instead, we create an iframe ourselves, using the srcdoc attribute. The only thing that we need to worry # is that there are no double quotes in the html, otherwise the srcdoc attribute will be broken. # So we replace " by ", the html entity for ". if to_export: escaped_html = html.escape(html_text) - obj = HTML(f"""""") - else: - obj = HTML(f'
{html_text}
', metadata={"isolated": True}, ) + obj = HTML( + f"""""" + ) + else: + obj = HTML( + f'
{html_text}
', + metadata={"isolated": True}, + ) display(obj) else: @@ -396,13 +471,21 @@ def _show_pyvis(net: "pyvis.Network", notebook: bool, to_export: Union[bool, Non with tempfile.NamedTemporaryFile("w", suffix=".html", delete=False) as f: f.write(html_text) name = f.name - + webbrowser.open(name) - def visualize(self, colorscale: str = "viridis", show_workflow_inputs: bool = True, - edge_labels: bool = True, node_help: bool = True, - notebook: bool = False, hierarchial: bool = True, node_props: Dict[str, Any] = {}, - inputs_props: Dict[str, Any] = {}, leafs_props: Dict[str, Any] = {}, output_props: Dict[str, Any] = {}, + def visualize( + self, + colorscale: str = "viridis", + show_workflow_inputs: bool = True, + edge_labels: bool = True, + node_help: bool = True, + notebook: bool = False, + hierarchial: bool = True, + node_props: Dict[str, Any] = {}, + inputs_props: Dict[str, Any] = {}, + leafs_props: Dict[str, Any] = {}, + output_props: Dict[str, Any] = {}, to_export: Union[bool, None] = None, ): """Visualize the workflow's network in a plot. @@ -410,7 +493,7 @@ def visualize(self, colorscale: str = "viridis", show_workflow_inputs: bool = Tr The nodes of the graph are the node functions that compose the workflow. The edges represent connections between nodes, where one node's output is sent to another node's input. - + Parameters ---------- colorscale: str, optional @@ -448,23 +531,35 @@ def visualize(self, colorscale: str = "viridis", show_workflow_inputs: bool = Tr to_export = get_environ_variable("SISL_NODES_EXPORT_VIS") != False net = self.to_pyvis( - colorscale=colorscale, show_workflow_inputs=show_workflow_inputs, edge_labels=edge_labels, - node_help=node_help, notebook=notebook, - hierarchial=hierarchial, node_props=node_props, - inputs_props=inputs_props, leafs_props=leafs_props, output_props=output_props, + colorscale=colorscale, + show_workflow_inputs=show_workflow_inputs, + edge_labels=edge_labels, + node_help=node_help, + notebook=notebook, + hierarchial=hierarchial, + node_props=node_props, + inputs_props=inputs_props, + leafs_props=leafs_props, + output_props=output_props, to_export=to_export, ) return self._show_pyvis(net, notebook=notebook, to_export=to_export) -class WorkflowNodes: +class WorkflowNodes: inputs: Dict[str, WorkflowInput] workers: Dict[str, Node] output: WorkflowOutput named_vars: Dict[str, str] - def __init__(self, inputs: Dict[str, WorkflowInput], workers: Dict[str, Node], output: WorkflowOutput, named_vars: Dict[str, str]): + def __init__( + self, + inputs: Dict[str, WorkflowInput], + workers: Dict[str, Node], + output: WorkflowOutput, + named_vars: Dict[str, str], + ): self.inputs = inputs self.workers = workers self.output = output @@ -474,8 +569,12 @@ def __init__(self, inputs: Dict[str, WorkflowInput], workers: Dict[str, Node], o self._all_nodes = ChainMap(self.inputs, self.workers, {"output": self.output}) @classmethod - def from_workflow_run(cls, inputs: Dict[str, WorkflowInput], output: WorkflowOutput, named_vars: Dict[str, Node]): - + def from_workflow_run( + cls, + inputs: Dict[str, WorkflowInput], + output: WorkflowOutput, + named_vars: Dict[str, Node], + ): # Gather all worker nodes inside the workflow. workers = cls.gather_from_inputs_and_output(inputs.values(), output=output) @@ -487,14 +586,15 @@ def from_workflow_run(cls, inputs: Dict[str, WorkflowInput], output: WorkflowOut _named_vars[k] = node_key break - return cls(inputs=inputs, workers=workers, output=output, named_vars=_named_vars) - + return cls( + inputs=inputs, workers=workers, output=output, named_vars=_named_vars + ) + @classmethod def from_node_tree(cls, output_node): - # Gather all worker nodes inside the workflow. workers = cls.gather_from_inputs_and_output([], output=output_node) - + # Dictionary that will store the workflow input nodes. wf_inputs = {} # The workers found by traversing are node instances that might be in use @@ -508,49 +608,66 @@ def from_node_tree(cls, output_node): # Find out the inputs that we should connect to the workflow inputs. We connect all inputs # that are not nodes, and that are not the args or kwargs inputs. node_inputs = { - param_k: WorkflowInput(input_key=f"{node.__class__.__name__}_{param_k}", value=node.inputs[param_k]) - for param_k, v in node.inputs.items() if not ( - isinstance(v, Node) or param_k == node._args_inputs_key or param_k == node._kwargs_inputs_key + param_k: WorkflowInput( + input_key=f"{node.__class__.__name__}_{param_k}", + value=node.inputs[param_k], + ) + for param_k, v in node.inputs.items() + if not ( + isinstance(v, Node) + or param_k == node._args_inputs_key + or param_k == node._kwargs_inputs_key ) } # Create a new node using the newly determined inputs. However, we keep the links to the old nodes # These inputs will be updated later. with temporal_context(lazy=True): - new_workers[k] = node.__class__().update_inputs(**{**node.inputs, **node_inputs}) + new_workers[k] = node.__class__().update_inputs( + **{**node.inputs, **node_inputs} + ) # Register this new node in the mapping from old to new nodes. old_to_new[id(node)] = new_workers[k] - + # Update the workflow inputs dictionary with the inputs that we have determined # to be connected to this node. We use the node class name as a prefix to avoid # name clashes. THIS IS NOT PERFECT, IF THERE ARE TWO NODES OF THE SAME CLASS # THERE CAN BE A CLASH. - wf_inputs.update({ - f"{node.__class__.__name__}_{param_k}": v for param_k, v in node_inputs.items() - }) - + wf_inputs.update( + { + f"{node.__class__.__name__}_{param_k}": v + for param_k, v in node_inputs.items() + } + ) + # Now that we have all the node copies, update the links to old nodes with # links to new nodes. for k, node in new_workers.items(): - new_node_inputs = {} for param_k, v in node.inputs.items(): if param_k == node._args_inputs_key: - new_node_inputs[param_k] = [old_to_new[id(n)] if isinstance(n, Node) else n for n in v] + new_node_inputs[param_k] = [ + old_to_new[id(n)] if isinstance(n, Node) else n for n in v + ] elif param_k == node._args_inputs_key: - new_node_inputs[param_k] = {k: old_to_new[id(n)] if isinstance(n, Node) else n for k, n in v.items()} + new_node_inputs[param_k] = { + k: old_to_new[id(n)] if isinstance(n, Node) else n + for k, n in v.items() + } elif isinstance(v, Node) and not isinstance(v, WorkflowInput): new_node_inputs[param_k] = old_to_new[id(v)] - + with temporal_context(lazy=True): node.update_inputs(**new_node_inputs) # Create the workflow output. new_output = WorkflowOutput(value=old_to_new[id(output_node)]) - + # Initialize and return the WorkflowNodes object. - return cls(inputs=wf_inputs, workers=new_workers, output=new_output, named_vars={}) + return cls( + inputs=wf_inputs, workers=new_workers, output=new_output, named_vars={} + ) def __dir__(self) -> Iterable[str]: return dir(self.named_vars) + dir(self._all_nodes) @@ -580,7 +697,9 @@ def __str__(self): return f"Inputs: {self.inputs}\n\nWorkers: {self.workers}\n\nOutput: {self.output}\n\nNamed nodes: {self.named_vars}" @staticmethod - def gather_from_inputs_and_output(inputs: Sequence[WorkflowInput], output: WorkflowOutput) -> Dict[str, Node]: + def gather_from_inputs_and_output( + inputs: Sequence[WorkflowInput], output: WorkflowOutput + ) -> Dict[str, Node]: # Get a list of all nodes. nodes = [] @@ -593,7 +712,7 @@ def add_node(node): # Visit all nodes that depend on the inputs. traverse_tree_forward(inputs, add_node) - traverse_tree_backward((output, ), add_node) + traverse_tree_backward((output,), add_node) # Now build a dictionary that contains all nodes. dict_nodes = {} @@ -611,7 +730,7 @@ def add_node(node): # Add it to the dictionary of nodes dict_nodes[name] = node - + return dict_nodes def copy(self, inputs: Dict[str, Any] = {}) -> "WorkflowNodes": @@ -627,19 +746,18 @@ def copy(self, inputs: Dict[str, Any] = {}) -> "WorkflowNodes": new_inputs = {} for input_k, input_node in self.inputs.items(): new_inputs[input_k] = input_node.__class__( - input_key=input_node.input_key, value=inputs.get(input_k, input_node.value) + input_key=input_node.input_key, + value=inputs.get(input_k, input_node.value), ) # Now create a copy of the worker nodes - old_ids_to_key = { - id(node): key for key, node in self.workers.items() - } + old_ids_to_key = {id(node): key for key, node in self.workers.items()} new_output = [] new_workers = {} - + def copy_node(node): node_id = id(node) - + # If it is a WorkflowInput node, return the already created copy. if isinstance(node, WorkflowInput): return new_inputs[node.input_key] @@ -655,14 +773,13 @@ def copy_node(node): # This is an old node that we already copied, so we can just return the copy. if node_key in new_workers: return new_workers[node_key] - + # This is an old node that we haven't copied yet, so we need to copy it. # Get the new inputs, copying nodes if needed. new_node_inputs = node.map_inputs( - inputs=node.inputs, only_nodes=True, - func=copy_node + inputs=node.inputs, only_nodes=True, func=copy_node ) - + # Initialize a new node with the new inputs. args, kwargs = node._sanitize_inputs(new_node_inputs) new_node = node.__class__(*args, **kwargs) @@ -674,14 +791,20 @@ def copy_node(node): new_workers[node_key] = new_node return new_node - + with temporal_context(lazy=True): traverse_tree_forward(list(self.inputs.values()), copy_node) - traverse_tree_backward((self.output, ), copy_node) + traverse_tree_backward((self.output,), copy_node) assert len(new_output) == 1 and isinstance(new_output[0], WorkflowOutput) - - return self.__class__(inputs=new_inputs, workers=new_workers, output=new_output[0], named_vars=self.named_vars) + + return self.__class__( + inputs=new_inputs, + workers=new_workers, + output=new_output[0], + named_vars=self.named_vars, + ) + class Workflow(Node): # The nodes of the initial dry run. This is a class property @@ -696,10 +819,10 @@ class Workflow(Node): @classmethod def from_node_tree(cls, output_node: Node, workflow_name: Union[str, None] = None): """Creates a workflow class from a node. - + It does so by recursively traversing the tree in the inputs direction until it finds the leaves. - All the nodes found are included in the workflow. For each node, inputs + All the nodes found are included in the workflow. For each node, inputs that are not nodes are connected to the inputs of the workflow. Parameters @@ -717,23 +840,34 @@ def from_node_tree(cls, output_node: Node, workflow_name: Union[str, None] = Non """ # Create the node manager for the workflow. dryrun_nodes = WorkflowNodes.from_node_tree(output_node) - + # Create the signature of the workflow from the inputs that were determined # by the node manager. - signature = inspect.Signature(parameters=[ - inspect.Parameter(inp.input_key, inspect.Parameter.KEYWORD_ONLY, default=inp.value) for inp in dryrun_nodes.inputs.values() - ]) + signature = inspect.Signature( + parameters=[ + inspect.Parameter( + inp.input_key, inspect.Parameter.KEYWORD_ONLY, default=inp.value + ) + for inp in dryrun_nodes.inputs.values() + ] + ) def function(*args, **kwargs): - raise NotImplementedError("Workflow class created from node tree. Calling it as a function is not supported.") + raise NotImplementedError( + "Workflow class created from node tree. Calling it as a function is not supported." + ) function.__signature__ = signature # Create the class and return it. return type( - workflow_name or output_node.__class__.__name__, - (cls,), - {"dryrun_nodes": dryrun_nodes, "__signature__": signature, "function": staticmethod(function)} + workflow_name or output_node.__class__.__name__, + (cls,), + { + "dryrun_nodes": dryrun_nodes, + "__signature__": signature, + "function": staticmethod(function), + }, ) def setup(self, *args, **kwargs): @@ -743,7 +877,7 @@ def setup(self, *args, **kwargs): self.nodes = self.dryrun_nodes.copy(inputs=self._inputs) def __init_subclass__(cls): - # If this is just a subclass of Workflow that is not meant to be ran, continue + # If this is just a subclass of Workflow that is not meant to be ran, continue if not hasattr(cls, "function"): return super().__init_subclass__() # Also, if the node manager has already been created, continue. @@ -763,7 +897,9 @@ def assign_workflow_var(value: Any, var_name: str): repeats += 1 var_name = f"{original_name}_{repeats}" if var_name in named_vars: - raise ValueError(f"Variable {var_name} has already been assigned a value, in workflows you can't overwrite variables.") + raise ValueError( + f"Variable {var_name} has already been assigned a value, in workflows you can't overwrite variables." + ) named_vars[var_name] = value return value @@ -775,16 +911,20 @@ def assign_workflow_var(value: Any, var_name: str): # Get the signature of the function. sig = inspect.signature(work_func) - # Run a dryrun of the workflow, so that we can understand how the nodes are connected. + # Run a dryrun of the workflow, so that we can understand how the nodes are connected. # To this end, nodes must behave lazily. - with temporal_context(lazy=True): + with temporal_context(lazy=True): # Define all workflow inputs. inps = { k: WorkflowInput( - input_key=k, - value=param.default if param.default != inspect.Parameter.empty else Node._blank - ) for k, param in sig.parameters.items() - if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + input_key=k, + value=param.default + if param.default != inspect.Parameter.empty + else Node._blank, + ) + for k, param in sig.parameters.items() + if param.kind + not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) } # Run the workflow function. @@ -792,9 +932,11 @@ def assign_workflow_var(value: Any, var_name: str): # Connect the final node to the output of the workflow. out = WorkflowOutput(value=final_node) - + # Store all the nodes of the workflow. - cls.dryrun_nodes = WorkflowNodes.from_workflow_run(inputs=inps, output=out, named_vars=named_vars) + cls.dryrun_nodes = WorkflowNodes.from_workflow_run( + inputs=inps, output=out, named_vars=named_vars + ) return super().__init_subclass__() @@ -804,7 +946,7 @@ def __dir__(self) -> Iterable[str]: @classmethod def final_node_key(cls, *args) -> str: """Returns the key of the final (output) node of the workflow.""" - return cls.find_node_key(cls.dryrun_nodes.output._inputs['value'], *args) + return cls.find_node_key(cls.dryrun_nodes.output._inputs["value"], *args) @classmethod def find_node_key(cls, node, *args) -> str: @@ -816,44 +958,53 @@ def find_node_key(cls, node, *args) -> str: if len(args) == 1: return args[0] - raise ValueError(f"Could not find node {node} in the workflow. Workflow nodes {cls.dryrun_nodes.items()}") + raise ValueError( + f"Could not find node {node} in the workflow. Workflow nodes {cls.dryrun_nodes.items()}" + ) def get(self): """Returns the up to date output of the workflow. - + It will recompute it if necessary. """ return self.nodes.output.get() - + def update_inputs(self, **inputs): """Updates the inputs of the workflow.""" - # Be careful here: - + # Be careful here: + # We should implement something that halts automatic recalculation. # Otherwise the nodes will be recalculated every time we update each # individual input. for input_key, value in inputs.items(): self.nodes.inputs[input_key].update_inputs(value=value) - + self._inputs.update(inputs) - + return self - + def _get_output(self): return self.nodes.output._output - + def _set_output(self, value): self.nodes.output._output = value - + _output = property(_get_output, _set_output) + class NodeConverter(ast.NodeTransformer): """AST transformer that converts a function into a workflow.""" - - def __init__(self, *args, assign_fn: Union[str, None] = None, node_cls_name: str = "Node", **kwargs): + + def __init__( + self, + *args, + assign_fn: Union[str, None] = None, + node_cls_name: str = "Node", + **kwargs, + ): super().__init__(*args, **kwargs) - + self.assign_fn = assign_fn self.node_cls_name = node_cls_name @@ -861,37 +1012,45 @@ def visit_Call(self, node): """Converts some_module.some_attr(some_args) into Node.from_func(some_module.some_attr)(some_args)""" node2 = ast.Call( func=ast.Call( - func=ast.Attribute(value=ast.Name(id=self.node_cls_name, ctx=ast.Load()), attr='from_func', ctx=ast.Load()), - args=[self.visit(node.func)], keywords=[]), - args=[self.visit(arg) for arg in node.args], - keywords=[self.visit(keyword) for keyword in node.keywords] + func=ast.Attribute( + value=ast.Name(id=self.node_cls_name, ctx=ast.Load()), + attr="from_func", + ctx=ast.Load(), + ), + args=[self.visit(node.func)], + keywords=[], + ), + args=[self.visit(arg) for arg in node.args], + keywords=[self.visit(keyword) for keyword in node.keywords], ) - + ast.fix_missing_locations(node2) - + return node2 - + def visit_Assign(self, node): """Converts some_module.some_attr(some_args) into Node.from_func(some_module.some_attr)(some_args)""" - + if self.assign_fn is None: return self.generic_visit(node) if len(node.targets) > 1 or not isinstance(node.targets[0], ast.Name): return self.generic_visit(node) - + node.value = ast.Call( func=ast.Name(id=self.assign_fn, ctx=ast.Load()), args=[], keywords=[ - ast.keyword(arg='value', value=self.visit(node.value)), - ast.keyword(arg='var_name', value=ast.Constant(value=node.targets[0].id)) + ast.keyword(arg="value", value=self.visit(node.value)), + ast.keyword( + arg="var_name", value=ast.Constant(value=node.targets[0].id) + ), ], ) - + ast.fix_missing_locations(node.value) - + return node - + def visit_List(self, node): """Converts the list syntax into a call to the ListSyntaxNode.""" if all(isinstance(elt, ast.Constant) for elt in node.elts): @@ -900,13 +1059,13 @@ def visit_List(self, node): new_node = ast.Call( func=ast.Name(id="ListSyntaxNode", ctx=ast.Load()), args=[self.visit(elt) for elt in node.elts], - keywords=[] + keywords=[], ) ast.fix_missing_locations(new_node) return new_node - + def visit_Tuple(self, node): """Converts the tuple syntax into a call to the TupleSyntaxNode.""" if all(isinstance(elt, ast.Constant) for elt in node.elts): @@ -915,20 +1074,20 @@ def visit_Tuple(self, node): new_node = ast.Call( func=ast.Name(id="TupleSyntaxNode", ctx=ast.Load()), args=[self.visit(elt) for elt in node.elts], - keywords=[] + keywords=[], ) ast.fix_missing_locations(new_node) return new_node - + def visit_Dict(self, node: ast.Dict) -> Any: """Converts the dict syntax into a call to the DictSyntaxNode.""" if all(isinstance(elt, ast.Constant) for elt in node.values): return self.generic_visit(node) if not all(isinstance(elt, ast.Constant) for elt in node.keys): return self.generic_visit(node) - + new_node = ast.Call( func=ast.Name(id="DictSyntaxNode", ctx=ast.Load()), args=[], @@ -941,20 +1100,19 @@ def visit_Dict(self, node: ast.Dict) -> Any: ast.fix_missing_locations(new_node) return new_node - - + def nodify_func( - func: FunctionType, - transformer_cls: Type[NodeConverter] = NodeConverter, - assign_fn: Union[Callable, None] = None, - node_cls: Type[Node] = Node + func: FunctionType, + transformer_cls: Type[NodeConverter] = NodeConverter, + assign_fn: Union[Callable, None] = None, + node_cls: Type[Node] = Node, ) -> FunctionType: """Converts all calculations of a function into nodes. - + This is used for example to convert a function into a workflow. - The conversion is done by getting the function's source code, parsing it + The conversion is done by getting the function's source code, parsing it into an abstract syntax tree, modifying the tree and recompiling. Parameters @@ -973,13 +1131,17 @@ def nodify_func( """ # Get the function's namespace. closurevars = inspect.getclosurevars(func) - func_namespace = {**closurevars.nonlocals, **closurevars.globals, **closurevars.builtins} + func_namespace = { + **closurevars.nonlocals, + **closurevars.globals, + **closurevars.builtins, + } # Get the function's source code. code = inspect.getsource(func) # Make sure the first line is at the 0 indentation level. code = textwrap.dedent(code) - + # Parse the source code into an AST. tree = ast.parse(code) @@ -987,9 +1149,11 @@ def nodify_func( # support arbitrary decorators. decorators = tree.body[0].decorator_list if len(decorators) > 0: - warn(f"Decorators are ignored for now on workflow creation. Ignoring {len(decorators)} decorators on {func.__name__}") + warn( + f"Decorators are ignored for now on workflow creation. Ignoring {len(decorators)} decorators on {func.__name__}" + ) tree.body[0].decorator_list = [] - + # The alias of the assign_fn function, which we make sure does not conflict # with any other variable in the function's namespace. assign_fn_key = None @@ -1003,21 +1167,23 @@ def nodify_func( node_cls_name = node_cls.__name__ while assign_fn_key in func_namespace: assign_fn_key += "_" - + # Transform the AST. transformer = transformer_cls(assign_fn=assign_fn_key, node_cls_name=node_cls_name) new_tree = transformer.visit(tree) # Compile the new AST into a code object. The filename is fake, but it doesn't - # matter because there is no file to map the code to. + # matter because there is no file to map the code to. # (we could map it to the original function in the future) code_obj = compile(new_tree, "compiled_workflows", "exec") - + # Add the needed variables into the namespace. namespace = { - node_cls_name: node_cls, - "ListSyntaxNode": ListSyntaxNode, "TupleSyntaxNode": TupleSyntaxNode, "DictSyntaxNode": DictSyntaxNode, - **func_namespace, + node_cls_name: node_cls, + "ListSyntaxNode": ListSyntaxNode, + "TupleSyntaxNode": TupleSyntaxNode, + "DictSyntaxNode": DictSyntaxNode, + **func_namespace, } if assign_fn_key is not None: namespace[assign_fn_key] = assign_fn diff --git a/src/sisl/oplist.py b/src/sisl/oplist.py index 59768350fd..c78d18fa51 100644 --- a/src/sisl/oplist.py +++ b/src/sisl/oplist.py @@ -20,7 +20,7 @@ def yield_oplist(oplist_, rhs): - """ Yield elements from `oplist_` and `rhs` + """Yield elements from `oplist_` and `rhs` This also ensures that ``len(oplist_)`` and ``len(rhs)`` are the same. @@ -31,24 +31,30 @@ def yield_oplist(oplist_, rhs): """ n_lhs = len(oplist_) n = 0 - for l, r in zip_longest(oplist_, rhs, fillvalue=0.): + for l, r in zip_longest(oplist_, rhs, fillvalue=0.0): n += 1 if n_lhs >= n: yield l, r if n_lhs != n: - raise ValueError(f"{oplist_.__class__.__name__} requires other data to contain same number of elements (or a scalar).") + raise ValueError( + f"{oplist_.__class__.__name__} requires other data to contain same number of elements (or a scalar)." + ) def _crt_op(op, unitary=False): if unitary: + def func_op(self): return self.__class__(op(s) for s in self) + else: + def func_op(self, other): if isiterable(other): return self.__class__(op(s, o) for s, o in yield_oplist(self, other)) return self.__class__(op(s, other) for s in self) + return func_op @@ -57,30 +63,35 @@ def func_op(self, other): if isiterable(other): return self.__class__(op(o, s) for s, o in yield_oplist(self, other)) return self.__class__(op(other, s) for s in self) + return func_op def _crt_iop(op, unitary=False): if unitary: + def func_op(self): - for i in range(len(self)): # pylint: disable=C0200 + for i in range(len(self)): # pylint: disable=C0200 self[i] = op(self[i]) return self + else: + def func_op(self, other): if isiterable(other): for i, (s, o) in enumerate(yield_oplist(self, other)): self[i] = op(s, o) else: - for i in range(len(self)): # pylint: disable=C0200 + for i in range(len(self)): # pylint: disable=C0200 self[i] = op(self[i], other) return self + return func_op @set_module("sisl") class oplist(list): - """ list with element-wise operations + """list with element-wise operations List-inherited class implementing direct element operations instead of list-extensions/compressions. When having multiple lists and one wishes to create a sum of individual elements, thus @@ -143,11 +154,12 @@ class oplist(list): iterable : data elements in `oplist` """ + __slots__ = () @classmethod def decorate(cls, func): - """ Decorate a function to always return an `oplist`, regardless of return values from `func` + """Decorate a function to always return an `oplist`, regardless of return values from `func` Parameters ---------- @@ -171,6 +183,7 @@ def decorate(cls, func): [1] """ + @wraps(func) def wrap_func(*args, **kwargs): val = func(*args, **kwargs) diff --git a/src/sisl/orbital.py b/src/sisl/orbital.py index 03bf381aea..bde0f604cc 100644 --- a/src/sisl/orbital.py +++ b/src/sisl/orbital.py @@ -33,10 +33,13 @@ from .utils.mathematics import cart2spher __all__ = [ - "Orbital", "SphericalOrbital", "AtomicOrbital", + "Orbital", + "SphericalOrbital", + "AtomicOrbital", "HydrogenicOrbital", - "GTOrbital", "STOrbital", - "radial_minimize_range" + "GTOrbital", + "STOrbital", + "radial_minimize_range", ] @@ -44,10 +47,11 @@ def _rfact(l, m): pi4 = 4 * pi if m == 0: - return msqrt((2*l + 1)/pi4) + return msqrt((2 * l + 1) / pi4) elif m < 0: - return -msqrt(2*(2*l + 1)/pi4 * fact(l-m)/fact(l+m)) * (-1) ** m - return msqrt(2*(2*l + 1)/pi4 * fact(l-m)/fact(l+m)) + return -msqrt(2 * (2 * l + 1) / pi4 * fact(l - m) / fact(l + m)) * (-1) ** m + return msqrt(2 * (2 * l + 1) / pi4 * fact(l - m) / fact(l + m)) + # This is a tuple of dicts # [0]{0} is l==0, m==0 @@ -57,10 +61,7 @@ def _rfact(l, m): # Calculate it up to l == 7 which is the j shell # It will never be used, but in case somebody wishes to play with spherical harmonics # then why not ;) -_rspher_harm_fact = tuple( - {m: _rfact(l, m) for m in range(-l, l+1)} - for l in range(8) -) +_rspher_harm_fact = tuple({m: _rfact(l, m) for m in range(-l, l + 1)} for l in range(8)) # Clean-up del _rfact @@ -97,13 +98,13 @@ def _rspherical_harm(m, l, theta, cos_phi): if m == 0: return _rspher_harm_fact[l][m] * lpmv(m, l, cos_phi) elif m < 0: - return _rspher_harm_fact[l][m] * (lpmv(m, l, cos_phi) * sin(m*theta)) - return _rspher_harm_fact[l][m] * (lpmv(m, l, cos_phi) * cos(m*theta)) + return _rspher_harm_fact[l][m] * (lpmv(m, l, cos_phi) * sin(m * theta)) + return _rspher_harm_fact[l][m] * (lpmv(m, l, cos_phi) * cos(m * theta)) @set_module("sisl") class Orbital: - r""" Base class for orbital information. + r"""Base class for orbital information. The orbital class is still in an experimental stage and will probably evolve over some time. @@ -155,8 +156,8 @@ class Orbital: """ __slots__ = ("_R", "_tag", "_q0") - def __init__(self, R, q0=0., tag=""): - """ Initialize orbital object """ + def __init__(self, R, q0=0.0, tag=""): + """Initialize orbital object""" # Determine if the orbital has a radial function # In which case we can apply the radial discovery if R is None: @@ -172,7 +173,9 @@ def __init__(self, R, q0=0., tag=""): R = radial_minimize_range(self._radial, **R) elif isinstance(R, dict): - warn(f"{self.__class__.__name__} cannot optimize R without a radial function.") + warn( + f"{self.__class__.__name__} cannot optimize R without a radial function." + ) R = R.get("contains", -0.9999) self._R = float(R) @@ -184,21 +187,21 @@ def __hash__(self): @property def R(self): - """ Maxmimum radius of orbital """ + """Maxmimum radius of orbital""" return self._R @property def q0(self): - """ Initial charge """ + """Initial charge""" return self._q0 @property def tag(self): - """ Named tag of orbital """ + """Named tag of orbital""" return self._tag def __str__(self): - """ A string representation of the object """ + """A string representation of the object""" if len(self.tag) > 0: return f"{self.__class__.__name__}{{R: {self.R:.5f}, q0: {self.q0}, tag: {self.tag}}}" return f"{self.__class__.__name__}{{R: {self.R:.5f}, q0: {self.q0}}}" @@ -209,15 +212,15 @@ def __repr__(self): return f"<{self.__module__}.{self.__class__.__name__} R={self.R:.3f}, q0={self.q0}>" def name(self, tex=False): - """ Return a named specification of the orbital (`tag`) """ + """Return a named specification of the orbital (`tag`)""" return self.tag def psi(self, r, *args, **kwargs): - r""" Calculate :math:`\phi(\mathbf R)` for Cartesian coordinates """ + r"""Calculate :math:`\phi(\mathbf R)` for Cartesian coordinates""" raise NotImplementedError def toSphere(self, center=None): - """ Return a sphere with radius equal to the orbital size + """Return a sphere with radius equal to the orbital size Returns ------- @@ -227,7 +230,7 @@ def toSphere(self, center=None): return Sphere(self.R, center) def equal(self, other, psi=False, radial=False): - """ Compare two orbitals by comparing their radius, and possibly the radial and psi functions + """Compare two orbitals by comparing their radius, and possibly the radial and psi functions When comparing two orbital radius they are considered *equal* with a precision of 1e-4 Ang. @@ -264,21 +267,21 @@ def equal(self, other, psi=False, radial=False): return same and self.tag == other.tag def copy(self): - """ Create an exact copy of this object """ + """Create an exact copy of this object""" return self.__class__(self.R, self.q0, self.tag) def scale(self, scale): - """ Scale the orbital by extending R by `scale` """ + """Scale the orbital by extending R by `scale`""" R = self.R * scale if R < 0: - R = -1. + R = -1.0 return self.__class__(R, self.q0, self.tag) def __eq__(self, other): return self.equal(other) def __plot__(self, harmonics=False, axes=False, *args, **kwargs): - """ Plot the orbital radial/spherical harmonics + """Plot the orbital radial/spherical harmonics Parameters ---------- @@ -299,7 +302,6 @@ def __plot__(self, harmonics=False, axes=False, *args, **kwargs): # Add plots if harmonics: - # Calculate the spherical harmonics theta, phi = np.meshgrid(np.arange(360), np.arange(180) - 90) s = self.spher(np.radians(theta), np.radians(phi)) @@ -310,8 +312,8 @@ def __plot__(self, harmonics=False, axes=False, *args, **kwargs): axes.get_figure().colorbar(cax) axes.set_title(r"${}$".format(self.name(True))) # I don't know how exactly to handle this... - #axes.set_xlabel(r"Azimuthal angle $\theta$") - #axes.set_ylabel(r"Polar angle $\phi$") + # axes.set_xlabel(r"Azimuthal angle $\theta$") + # axes.set_ylabel(r"Polar angle $\phi$") else: # Plot the radial function and 5% above 0 value @@ -324,8 +326,8 @@ def __plot__(self, harmonics=False, axes=False, *args, **kwargs): return axes - def toGrid(self, precision=0.05, c=1., R=None, dtype=np.float64, atom=1): - """ Create a Grid with *only* this orbital wavefunction on it + def toGrid(self, precision=0.05, c=1.0, R=None, dtype=np.float64, atom=1): + """Create a Grid with *only* this orbital wavefunction on it Parameters ---------- @@ -344,8 +346,10 @@ def toGrid(self, precision=0.05, c=1., R=None, dtype=np.float64, atom=1): if R is None: R = self.R if R < 0: - raise ValueError(f"{self.__class__.__name__}.toGrid was unable to create " - "the orbital grid for plotting, the box size is negative.") + raise ValueError( + f"{self.__class__.__name__}.toGrid was unable to create " + "the orbital grid for plotting, the box size is negative." + ) # Since all these things depend on other elements # we will simply import them here. @@ -354,7 +358,8 @@ def toGrid(self, precision=0.05, c=1., R=None, dtype=np.float64, atom=1): from .grid import Grid from .lattice import Lattice from .physics.electron import wavefunction - lattice = Lattice(R*2, origin=[-R] * 3) + + lattice = Lattice(R * 2, origin=[-R] * 3) if isinstance(atom, Atom): atom = atom.copy(orbitals=self) else: @@ -365,22 +370,25 @@ def toGrid(self, precision=0.05, c=1., R=None, dtype=np.float64, atom=1): return G def __getstate__(self): - """ Return the state of this object """ + """Return the state of this object""" return {"R": self.R, "q0": self.q0, "tag": self.tag} def __setstate__(self, d): - """ Re-create the state of this object """ + """Re-create the state of this object""" self.__init__(d["R"], q0=d["q0"], tag=d["tag"]) - RadialFuncT = Callable[[npt.ArrayLike], npt.NDArray] -def radial_minimize_range(radial_func: Callable[[RadialFuncT], npt.NDArray], - contains: float, - dr: Tuple[float, float]=(0.01, 0.0001), - maxR: float=100, - func: Optional[Callable[[RadialFuncT, npt.ArrayLike], npt.NDArray]]=None) -> float: - """ Minimize the maximum radius such that the integrated function `radial_func**2*r**3` contains `contains` of the integrand + + +def radial_minimize_range( + radial_func: Callable[[RadialFuncT], npt.NDArray], + contains: float, + dr: Tuple[float, float] = (0.01, 0.0001), + maxR: float = 100, + func: Optional[Callable[[RadialFuncT, npt.ArrayLike], npt.NDArray]] = None, +) -> float: + """Minimize the maximum radius such that the integrated function `radial_func**2*r**3` contains `contains` of the integrand Parameters ---------- @@ -444,7 +452,7 @@ def loc(intf, integrand): idx = 0 return idx + (intf[idx:] >= integrand).nonzero()[0] - r = np.arange(0., maxR + dr[0]/2, dr[0]) + r = np.arange(0.0, maxR + dr[0] / 2, dr[0]) f = func(radial_func, r) intf = cumulative_trapezoid(f, dx=dr[0], initial=0) integrand = intf[-1] * contains @@ -458,12 +466,12 @@ def loc(intf, integrand): # in the trapezoid integration each point is half contributed # to the previous point and half to the following point. # Here intf[idx-1] is the closed integral from 0:r[idx-1] - idxm_integrand = intf[idx-1] + idxm_integrand = intf[idx - 1] # Preset R R = r[idx] - r = np.arange(R - dr[0], min(R + dr[0]*2, maxR) + dr[1]/2, dr[1]) + r = np.arange(R - dr[0], min(R + dr[0] * 2, maxR) + dr[1] / 2, dr[1]) f = func(radial_func, r) intf = cumulative_trapezoid(f, dx=dr[1], initial=0) + idxm_integrand @@ -478,12 +486,14 @@ def loc(intf, integrand): except AttributeError: func_name = radial_func.__name__ - warn(f"{func_name} failed to detect a proper radius for integration purposes, retaining R=-{contains}") + warn( + f"{func_name} failed to detect a proper radius for integration purposes, retaining R=-{contains}" + ) return -contains def _set_radial(self, *args, **kwargs) -> None: - r""" Update the internal radial function used as a :math:`f(|\mathbf r|)` + r"""Update the internal radial function used as a :math:`f(|\mathbf r|)` This can be called in several ways: @@ -540,9 +550,11 @@ def _set_radial(self, *args, **kwargs) -> None: True """ if len(args) == 0: + def f0(r): - """ Wrapper for returning 0s """ + """Wrapper for returning 0s""" return np.zeros_like(r) + self._radial = f0 # we cannot set R since it will always give the largest distance @@ -550,7 +562,6 @@ def f0(r): self._radial = args[0] elif len(args) > 1: - # A radial and function component has been passed r = _a.asarrayd(args[0]) f = _a.asarrayd(args[1]) @@ -570,11 +581,13 @@ def f0(r): # this will defer the actual R designation (whether it should be set or not) self._radial = interp else: - raise ValueError(f"{self.__class__.__name__}.set_radial could not determine the arguments, please correct.") + raise ValueError( + f"{self.__class__.__name__}.set_radial could not determine the arguments, please correct." + ) def _radial(self, r, *args, **kwargs) -> np.ndarray: - r""" Calculate the radial part of spherical orbital :math:`R(\mathbf r)` + r"""Calculate the radial part of spherical orbital :math:`R(\mathbf r)` The position `r` is a vector from the origin of this orbital. @@ -609,7 +622,7 @@ def _radial(self, r, *args, **kwargs) -> np.ndarray: @set_module("sisl") class SphericalOrbital(Orbital): - r""" An *arbitrary* orbital class which only contains the harmonical part of the wavefunction where :math:`\phi(\mathbf r)=f(|\mathbf r|)Y_l^m(\theta,\varphi)` + r"""An *arbitrary* orbital class which only contains the harmonical part of the wavefunction where :math:`\phi(\mathbf r)=f(|\mathbf r|)Y_l^m(\theta,\varphi)` Note that in this case the used spherical harmonics is: @@ -659,8 +672,8 @@ class SphericalOrbital(Orbital): # Additional slots (inherited classes retain the same slots) __slots__ = ("_l", "_radial") - def __init__(self, l, rf_or_func, q0=0., tag="", **kwargs): - """ Initialize spherical orbital object """ + def __init__(self, l, rf_or_func, q0=0.0, tag="", **kwargs): + """Initialize spherical orbital object""" self._l = l # Set the internal function @@ -688,7 +701,7 @@ def __hash__(self): radial = _radial def spher(self, theta, phi, m=0, cos_phi=False): - r""" Calculate the spherical harmonics of this orbital at a given point (in spherical coordinates) + r"""Calculate the spherical harmonics of this orbital at a given point (in spherical coordinates) Parameters ----------- @@ -712,7 +725,7 @@ def spher(self, theta, phi, m=0, cos_phi=False): return _rspherical_harm(m, self.l, theta, cos(phi)) def psi(self, r, m=0): - r""" Calculate :math:`\phi(\mathbf R)` at a given point (or more points) + r"""Calculate :math:`\phi(\mathbf R)` at a given point (or more points) The position `r` is a vector from the origin of this orbital. @@ -741,7 +754,7 @@ def psi(self, r, m=0): return p def psi_spher(self, r, theta, phi, m=0, cos_phi=False): - r""" Calculate :math:`\phi(|\mathbf R|, \theta, \phi)` at a given point (in spherical coordinates) + r"""Calculate :math:`\phi(|\mathbf R|, \theta, \phi)` at a given point (in spherical coordinates) This is equivalent to `psi` however, the input is given in spherical coordinates. @@ -768,15 +781,15 @@ def psi_spher(self, r, theta, phi, m=0, cos_phi=False): @property def l(self): - r""" :math:`l` quantum number """ + r""":math:`l` quantum number""" return self._l def copy(self): - """ Create an exact copy of this object """ + """Create an exact copy of this object""" return self.__class__(self.l, self._radial, R=self.R, q0=self.q0, tag=self.tag) def equal(self, other, psi=False, radial=False): - """ Compare two orbitals by comparing their radius, and possibly the radial and psi functions + """Compare two orbitals by comparing their radius, and possibly the radial and psi functions Parameters ---------- @@ -795,7 +808,7 @@ def equal(self, other, psi=False, radial=False): return same def __str__(self): - """ A string representation of the object """ + """A string representation of the object""" if len(self.tag) > 0: return f"{self.__class__.__name__}{{l: {self.l}, R: {self.R}, q0: {self.q0}, tag: {self.tag}}}" return f"{self.__class__.__name__}{{l: {self.l}, R: {self.R}, q0: {self.q0}}}" @@ -806,7 +819,7 @@ def __repr__(self): return f"<{self.__module__}.{self.__class__.__name__} l={self.l}, R={self.R:.3f}, q0={self.q0}>" def toAtomicOrbital(self, m=None, n=None, zeta=1, P=False, q0=None): - r""" Create a list of `AtomicOrbital` objects + r"""Create a list of `AtomicOrbital` objects This defaults to create a list of `AtomicOrbital` objects for every `m` (for m in -l:l). One may optionally specify the sub-set of `m` to retrieve. @@ -835,20 +848,27 @@ def toAtomicOrbital(self, m=None, n=None, zeta=1, P=False, q0=None): if m is None: m = range(-self.l, self.l + 1) elif isinstance(m, Integral): - return AtomicOrbital(n=n, l=self.l, m=m, zeta=zeta, P=P, spherical=self, q0=q0, R=self.R) - return [AtomicOrbital(n=n, l=self.l, m=mm, zeta=zeta, P=P, spherical=self, q0=q0, R=self.R) for mm in m] + return AtomicOrbital( + n=n, l=self.l, m=m, zeta=zeta, P=P, spherical=self, q0=q0, R=self.R + ) + return [ + AtomicOrbital( + n=n, l=self.l, m=mm, zeta=zeta, P=P, spherical=self, q0=q0, R=self.R + ) + for mm in m + ] def __getstate__(self): - """ Return the state of this object """ + """Return the state of this object""" # A function is not necessarily pickable, so we store interpolated # data which *should* ensure the correct pickable state (to close agreement) r = np.linspace(0, self.R, 1000) f = self.radial(r) - return {'l': self.l, 'r': r, 'f': f, 'q0': self.q0, 'tag': self.tag} + return {"l": self.l, "r": r, "f": f, "q0": self.q0, "tag": self.tag} def __setstate__(self, d): - """ Re-create the state of this object """ - self.__init__(d['l'], (d['r'], d['f']), q0=d['q0'], tag=d['tag']) + """Re-create the state of this object""" + self.__init__(d["l"], (d["r"], d["f"]), q0=d["q0"], tag=d["tag"]) @set_module("sisl") @@ -912,7 +932,7 @@ class AtomicOrbital(Orbital): __slots__ = ("_n", "_l", "_m", "_zeta", "_P", "_orb") def __init__(self, *args, **kwargs): - """ Initialize atomic orbital object """ + """Initialize atomic orbital object""" # Ensure args is a list (to be able to pop) args = list(args) @@ -932,13 +952,32 @@ def __init__(self, *args, **kwargs): _n = {"s": 1, "p": 2, "d": 3, "f": 4, "g": 5} _l = {"s": 0, "p": 1, "d": 2, "f": 3, "g": 4} - _m = {"s": 0, - "pz": 0, "px": 1, "py": -1, - "dxy": -2, "dyz": -1, "dz2": 0, "dxz": 1, "dx2-y2": 2, - "fy(3x2-y2)": -3, "fxyz": -2, "fz2y": -1, "fz3": 0, - "fz2x": 1, "fz(x2-y2)": 2, "fx(x2-3y2)": 3, - "gxy(x2-y2)": -4, "gzy(3x2-y2)": -3, "gz2xy": -2, "gz3y": -1, "gz4": 0, - "gz3x": 1, "gz2(x2-y2)": 2, "gzx(x2-3y2)": 3, "gx4+y4": 4, + _m = { + "s": 0, + "pz": 0, + "px": 1, + "py": -1, + "dxy": -2, + "dyz": -1, + "dz2": 0, + "dxz": 1, + "dx2-y2": 2, + "fy(3x2-y2)": -3, + "fxyz": -2, + "fz2y": -1, + "fz3": 0, + "fz2x": 1, + "fz(x2-y2)": 2, + "fx(x2-3y2)": 3, + "gxy(x2-y2)": -4, + "gzy(3x2-y2)": -3, + "gz2xy": -2, + "gz3y": -1, + "gz4": 0, + "gz3x": 1, + "gz2(x2-y2)": 2, + "gzx(x2-3y2)": 3, + "gx4+y4": 4, } # First remove a P for polarization @@ -970,18 +1009,17 @@ def __init__(self, *args, **kwargs): # However, for now we assume this is enough (could easily # be extended by a reg-exp) try: - zeta = int(s[iZ+1]) + zeta = int(s[iZ + 1]) # Remove Z + int - s = s[:iZ] + s[iZ+2:] + s = s[:iZ] + s[iZ + 2 :] except Exception: zeta = 1 - s = s[:iZ] + s[iZ+1:] + s = s[:iZ] + s[iZ + 1 :] # We should be left with m specification m = _m.get(s, m) else: - # Arguments *have* to be # n, l, [m (only for l>0)] [, zeta [, P]] if n is None and len(args) > 0: @@ -1000,7 +1038,6 @@ def __init__(self, *args, **kwargs): if isinstance(args[0], bool): P = args.pop(0) - if l is None: raise ValueError(f"{self.__class__.__name__} l is not defined") @@ -1022,11 +1059,13 @@ def __init__(self, *args, **kwargs): raise ValueError(f"{self.__class__.__name__} n must be >= 1") if self.l >= len(_rspher_harm_fact): - raise ValueError(f"{self.__class__.__name__} does not implement shells l>={len(_rspher_harm_fact)}!") + raise ValueError( + f"{self.__class__.__name__} does not implement shells l>={len(_rspher_harm_fact)}!" + ) if abs(self.m) > self.l: raise ValueError(f"{self.__class__.__name__} requires |m| <= l.") - # Now we should figure out how the spherical orbital + # Now we should figure out how the spherical orbital # has been passed. # There are two options: # 1. The radial function is passed as two arrays: r, f @@ -1036,8 +1075,10 @@ def __init__(self, *args, **kwargs): if len(args) > 0: s = args.pop(0) if "spherical" in kwargs: - raise ValueError(f"{self.__class__.__name__} multiple values for the spherical " - "orbital is present, 1) argument, 2) spherical=. Only supply one of them.") + raise ValueError( + f"{self.__class__.__name__} multiple values for the spherical " + "orbital is present, 1) argument, 2) spherical=. Only supply one of them." + ) else: # in case the class has its own radial implementation, we might as well rely on that one @@ -1045,7 +1086,7 @@ def __init__(self, *args, **kwargs): # Get the radius requested R = kwargs.get("R") - q0 = kwargs.get("q0", 0.) + q0 = kwargs.get("q0", 0.0) if s is None: self._orb = Orbital(R, q0=q0) @@ -1058,45 +1099,63 @@ def __init__(self, *args, **kwargs): super().__init__(self._orb.R, q0=q0, tag=kwargs.get("tag", "")) def __hash__(self): - return hash((super(Orbital, self), self._l, self._n, self._m, - self._zeta, self._P, self._orb)) + return hash( + ( + super(Orbital, self), + self._l, + self._n, + self._m, + self._zeta, + self._P, + self._orb, + ) + ) @property def n(self): - r""" :math:`n` shell """ + r""":math:`n` shell""" return self._n @property def l(self): - r""" :math:`l` quantum number """ + r""":math:`l` quantum number""" return self._l @property def m(self): - r""" :math:`m` quantum number """ + r""":math:`m` quantum number""" return self._m @property def zeta(self): - r""" :math:`\zeta` shell """ + r""":math:`\zeta` shell""" return self._zeta @property def P(self): - r""" Whether this is polarized shell or not """ + r"""Whether this is polarized shell or not""" return self._P @property def orb(self): - r""" Orbital with radial part """ + r"""Orbital with radial part""" return self._orb def copy(self): - """ Create an exact copy of this object """ - return self.__class__(n=self.n, l=self.l, m=self.m, zeta=self.zeta, P=self.P, spherical=self.orb.copy(), q0=self.q0, tag=self.tag) + """Create an exact copy of this object""" + return self.__class__( + n=self.n, + l=self.l, + m=self.m, + zeta=self.zeta, + P=self.P, + spherical=self.orb.copy(), + q0=self.q0, + tag=self.tag, + ) def equal(self, other, psi=False, radial=False): - """ Compare two orbitals by comparing their radius, and possibly the radial and psi functions + """Compare two orbitals by comparing their radius, and possibly the radial and psi functions Parameters ---------- @@ -1121,33 +1180,61 @@ def equal(self, other, psi=False, radial=False): return same def name(self, tex=False): - """ Return named specification of the atomic orbital """ + """Return named specification of the atomic orbital""" if tex: name = "{}{}".format(self.n, "spdfghij"[self.l]) if self.l == 1: - name += ("_y", "_z", "_x")[self.m+1] + name += ("_y", "_z", "_x")[self.m + 1] elif self.l == 2: - name += ("_{xy}", "_{yz}", "_{z^2}", "_{xz}", "_{x^2-y^2}")[self.m+2] + name += ("_{xy}", "_{yz}", "_{z^2}", "_{xz}", "_{x^2-y^2}")[self.m + 2] elif self.l == 3: - name += ("_{y(3x^2-y^2)}", "_{xyz}", "_{z^2y}", "_{z^3}", "_{z^2x}", "_{z(x^2-y^2)}", "_{x(x^2-3y^2)}")[self.m+3] + name += ( + "_{y(3x^2-y^2)}", + "_{xyz}", + "_{z^2y}", + "_{z^3}", + "_{z^2x}", + "_{z(x^2-y^2)}", + "_{x(x^2-3y^2)}", + )[self.m + 3] elif self.l == 4: - name += ("_{_{xy(x^2-y^2)}}", "_{zy(3x^2-y^2)}", "_{z^2xy}", "_{z^3y}", "_{z^4}", - "_{z^3x}", "_{z^2(x^2-y^2)}", "_{zx(x^2-3y^2)}", "_{x^4+y^4}")[self.m+4] + name += ( + "_{_{xy(x^2-y^2)}}", + "_{zy(3x^2-y^2)}", + "_{z^2xy}", + "_{z^3y}", + "_{z^4}", + "_{z^3x}", + "_{z^2(x^2-y^2)}", + "_{zx(x^2-3y^2)}", + "_{x^4+y^4}", + )[self.m + 4] elif self.l >= 5: name = f"{name}_{{m={self.m}}}" if self.P: - return name + fr"\zeta^{self.zeta}\mathrm{{P}}" - return name + fr"\zeta^{self.zeta}" + return name + rf"\zeta^{self.zeta}\mathrm{{P}}" + return name + rf"\zeta^{self.zeta}" name = "{}{}".format(self.n, "spdfghij"[self.l]) if self.l == 1: - name += ("y", "z", "x")[self.m+1] + name += ("y", "z", "x")[self.m + 1] elif self.l == 2: - name += ("xy", "yz", "z2", "xz", "x2-y2")[self.m+2] + name += ("xy", "yz", "z2", "xz", "x2-y2")[self.m + 2] elif self.l == 3: - name += ("y(3x2-y2)", "xyz", "z2y", "z3", "z2x", "z(x2-y2)", "x(x2-3y2)")[self.m+3] + name += ("y(3x2-y2)", "xyz", "z2y", "z3", "z2x", "z(x2-y2)", "x(x2-3y2)")[ + self.m + 3 + ] elif self.l == 4: - name += ("xy(x2-y2)", "zy(3x2-y2)", "z2xy", "z3y", "z4", - "z3x", "z2(x2-y2)", "zx(x2-3y2)", "x4+y4")[self.m+4] + name += ( + "xy(x2-y2)", + "zy(3x2-y2)", + "z2xy", + "z3y", + "z4", + "z3x", + "z2(x2-y2)", + "zx(x2-3y2)", + "x4+y4", + )[self.m + 4] elif self.l >= 5: name = f"{name}(m={self.m})" if self.P: @@ -1155,25 +1242,29 @@ def name(self, tex=False): return name + f"Z{self.zeta}" def __str__(self): - """ A string representation of the object """ + """A string representation of the object""" if len(self.tag) > 0: return f"{self.__class__.__name__}{{{self.name()}, q0: {self.q0}, tag: {self.tag}, {self.orb!s}}}" - return f"{self.__class__.__name__}{{{self.name()}, q0: {self.q0}, {self.orb!s}}}" + return ( + f"{self.__class__.__name__}{{{self.name()}, q0: {self.q0}, {self.orb!s}}}" + ) def __repr__(self): if self.tag: return f"<{self.__module__}.{self.__class__.__name__} {self.name()} q0={self.q0}, tag={self.tag}>" - return f"<{self.__module__}.{self.__class__.__name__} {self.name()} q0={self.q0}>" + return ( + f"<{self.__module__}.{self.__class__.__name__} {self.name()} q0={self.q0}>" + ) def set_radial(self, *args, **kwargs): - r""" Update the internal radial function used as a :math:`f(|\mathbf r|)` + r"""Update the internal radial function used as a :math:`f(|\mathbf r|)` See `SphericalOrbital.set_radial` where these arguments are passed to. """ return self.orb.set_radial(*args, **kwargs) def radial(self, r, *args, **kwargs): - r""" Calculate the radial part of the wavefunction :math:`f(\mathbf R)` + r"""Calculate the radial part of the wavefunction :math:`f(\mathbf R)` The position `r` is a vector from the origin of this orbital. @@ -1190,7 +1281,7 @@ def radial(self, r, *args, **kwargs): return self.orb.radial(r, *args, **kwargs) def psi(self, r): - r""" Calculate :math:`\phi(\mathbf r)` at a given point (or more points) + r"""Calculate :math:`\phi(\mathbf r)` at a given point (or more points) The position `r` is a vector from the origin of this orbital. @@ -1207,7 +1298,7 @@ def psi(self, r): return self.orb.psi(r, self.m) def spher(self, theta, phi, cos_phi=False): - r""" Calculate the spherical harmonics of this orbital at a given point (in spherical coordinates) + r"""Calculate the spherical harmonics of this orbital at a given point (in spherical coordinates) Parameters ----------- @@ -1227,7 +1318,7 @@ def spher(self, theta, phi, cos_phi=False): return self.orb.spher(theta, phi, self.m, cos_phi) def psi_spher(self, r, theta, phi, cos_phi=False): - r""" Calculate :math:`\phi(|\mathbf R|, \theta, \phi)` at a given point (in spherical coordinates) + r"""Calculate :math:`\phi(|\mathbf R|, \theta, \phi)` at a given point (in spherical coordinates) This is equivalent to `psi` however, the input is given in spherical coordinates. @@ -1251,7 +1342,7 @@ def psi_spher(self, r, theta, phi, cos_phi=False): return self.orb.psi_spher(r, theta, phi, self.m, cos_phi) def __getstate__(self): - """ Return the state of this object """ + """Return the state of this object""" # A function is not necessarily pickable, so we store interpolated # data which *should* ensure the correct pickable state (to close agreement) try: @@ -1264,7 +1355,7 @@ def __getstate__(self): return {"name": self.name(), "r": r, "f": f, "q0": self.q0, "tag": self.tag} def __setstate__(self, d): - """ Re-create the state of this object """ + """Re-create the state of this object""" if d["r"] is None: self.__init__(d["name"], q0=d["q0"], tag=d["tag"]) else: @@ -1319,41 +1410,51 @@ class HydrogenicOrbital(AtomicOrbital): """ def __init__(self, n, l, m, Z, **kwargs): - self._Z = Z Helper = namedtuple("Helper", ["Z", "prefactor"]) z = 2 * Z / (n * a0("Ang")) - pref = (z ** 3 * factorial(n - l - 1) / (2 * n * factorial(n + l))) ** 0.5 + pref = (z**3 * factorial(n - l - 1) / (2 * n * factorial(n + l))) ** 0.5 self._radial_helper = Helper(z, pref) super().__init__(n, l, m, **kwargs) def copy(self): - """ Create an exact copy of this object """ - return self.__class__(self.n, self.l, self.m, self._Z, R=self.R, q0=self.q0, tag=self.tag) + """Create an exact copy of this object""" + return self.__class__( + self.n, self.l, self.m, self._Z, R=self.R, q0=self.q0, tag=self.tag + ) def _radial(self, r): - r""" Radial functional for the Hydrogenic orbital """ + r"""Radial functional for the Hydrogenic orbital""" H = self._radial_helper n = self.n l = self.l zr = H.Z * r L = H.prefactor * eval_genlaguerre(n - l - 1, 2 * l + 1, zr) - return np.exp(-zr * 0.5) * zr ** l * L + return np.exp(-zr * 0.5) * zr**l * L def __getstate__(self): - """ Return the state of this object """ - return {"n": self.n, "l": self.l, "m": self.m, - "Z": self._Z, "R": self.R, "q0": self.q0, "tag": self.tag} + """Return the state of this object""" + return { + "n": self.n, + "l": self.l, + "m": self.m, + "Z": self._Z, + "R": self.R, + "q0": self.q0, + "tag": self.tag, + } def __setstate__(self, d): - """ Re-create the state of this object """ - self.__init__(d["n"], d["l"], d["m"], d["Z"], R=d["R"], q0=d["q0"], tag=d["tag"]) + """Re-create the state of this object""" + self.__init__( + d["n"], d["l"], d["m"], d["Z"], R=d["R"], q0=d["q0"], tag=d["tag"] + ) class _ExponentialOrbital(Orbital): - r""" Inheritable class for different exponential spherical orbitals + r"""Inheritable class for different exponential spherical orbitals All exponential spherical orbitals are defined using: @@ -1373,7 +1474,6 @@ class _ExponentialOrbital(Orbital): __slots__ = ("_n", "_l", "_m", "_alpha", "_coeff") def __init__(self, *args, **kwargs): - # Ensure args is a list (to be able to pop) args = list(args) @@ -1430,10 +1530,14 @@ def __init__(self, *args, **kwargs): coeff = (coeff,) self._coeff = tuple(coeff) - assert len(self.alpha) == len(self.coeff), "Contraction factors and exponents needs to have same length" + assert len(self.alpha) == len( + self.coeff + ), "Contraction factors and exponents needs to have same length" if self.l >= len(_rspher_harm_fact): - raise ValueError(f"{self.__class__.__name__} does not implement shells l>={len(_rspher_harm_fact)}!") + raise ValueError( + f"{self.__class__.__name__} does not implement shells l>={len(_rspher_harm_fact)}!" + ) if abs(self.m) > self.l: raise ValueError(f"{self.__class__.__name__} requires |m| <= l.") @@ -1442,16 +1546,27 @@ def __init__(self, *args, **kwargs): super().__init__(*args, R=R, **kwargs) def copy(self): - """ Create an exact copy of this object """ - return self.__class__(n=self.n, l=self.l, m=self.m, alpha=self.alpha, coeff=self.coeff, R=self.R, q0=self.q0, tag=self.tag) + """Create an exact copy of this object""" + return self.__class__( + n=self.n, + l=self.l, + m=self.m, + alpha=self.alpha, + coeff=self.coeff, + R=self.R, + q0=self.q0, + tag=self.tag, + ) def __str__(self): - """ A string representation of the object """ + """A string representation of the object""" if len(self.tag) > 0: s = f"{self.__class__.__name__}{{n: {self.n}, l: {self.l}, m: {self.m}, R: {self.R}, q0: {self.q0}, tag: {self.tag}" else: s = f"{self.__class__.__name__}{{n: {self.n}, l: {self.l}, m: {self.m}, R: {self.R}, q0: {self.q0}" - orbs = ",\n c, a:".join([f"{c:.4f} , {a:.5f}" for c, a in zip(self.alpha, self.coeff)]) + orbs = ",\n c, a:".join( + [f"{c:.4f} , {a:.5f}" for c, a in zip(self.alpha, self.coeff)] + ) return f"{s}{orbs}\n}}" def __repr__(self): @@ -1460,35 +1575,37 @@ def __repr__(self): return f"<{self.__module__}.{self.__class__.__name__} n={self.n}, l={self.l}, m={self.m}, no={len(self.alpha)}, R={self.R:.3f}, q0={self.q0}, >" def __hash__(self): - return hash((super(Orbital, self), self.n, self.l, self.m, self.coeff, self.alpha)) + return hash( + (super(Orbital, self), self.n, self.l, self.m, self.coeff, self.alpha) + ) @property def n(self): - r""" :math:`n` quantum number """ + r""":math:`n` quantum number""" return self._n @property def l(self): - r""" :math:`l` quantum number """ + r""":math:`l` quantum number""" return self._l @property def m(self): - r""" :math:`m` quantum number """ + r""":math:`m` quantum number""" return self._m @property def alpha(self): - r""" :math:`\alpha` factors """ + r""":math:`\alpha` factors""" return self._alpha @property def coeff(self): - r""" :math:`c` contraction factors """ + r""":math:`c` contraction factors""" return self._coeff def psi(self, r): - r""" Calculate :math:`\phi(\mathbf r)` at a given point (or more points) + r"""Calculate :math:`\phi(\mathbf r)` at a given point (or more points) The position `r` is a vector from the origin of this orbital. @@ -1505,7 +1622,9 @@ def psi(self, r): r = _a.asarray(r) s = r.shape[:-1] # Convert to spherical coordinates - n, idx, r, theta, phi = cart2spher(r, theta=self.m != 0, cos_phi=True, maxR=self.R) + n, idx, r, theta, phi = cart2spher( + r, theta=self.m != 0, cos_phi=True, maxR=self.R + ) p = _a.zerosd(n) if len(idx) > 0: p[idx] = self.psi_spher(r, theta, phi, cos_phi=True) @@ -1515,7 +1634,7 @@ def psi(self, r): return p def spher(self, theta, phi, cos_phi=False): - r""" Calculate the spherical harmonics of this orbital at a given point (in spherical coordinates) + r"""Calculate the spherical harmonics of this orbital at a given point (in spherical coordinates) Parameters ----------- @@ -1537,7 +1656,7 @@ def spher(self, theta, phi, cos_phi=False): return _rspherical_harm(self.m, self.l, theta, cos(phi)) def psi_spher(self, r, theta, phi, cos_phi=False): - r""" Calculate :math:`\phi(|\mathbf R|, \theta, \phi)` at a given point (in spherical coordinates) + r"""Calculate :math:`\phi(|\mathbf R|, \theta, \phi)` at a given point (in spherical coordinates) This is equivalent to `psi` however, the input is given in spherical coordinates. @@ -1609,16 +1728,16 @@ class GTOrbital(_ExponentialOrbital): radial = _radial def _radial(self, r): - r""" Radial function """ + r"""Radial function""" r2 = np.square(r) coeff = self.coeff alpha = self.alpha - v = coeff[0] * np.exp(-alpha[0]*r2) + v = coeff[0] * np.exp(-alpha[0] * r2) for c, a in zip(coeff[1:], alpha[1:]): - v += c * np.exp(-a*r2) + v += c * np.exp(-a * r2) if self.l == 0: return v - return r ** self.l * v + return r**self.l * v class STOrbital(_ExponentialOrbital): @@ -1669,12 +1788,12 @@ class STOrbital(_ExponentialOrbital): radial = _radial def _radial(self, r): - r""" Radial function """ + r"""Radial function""" coeff = self.coeff alpha = self.alpha - v = coeff[0] * np.exp(-alpha[0]*r) + v = coeff[0] * np.exp(-alpha[0] * r) for c, a in zip(coeff[1:], alpha[1:]): - v += c * np.exp(-a*r) + v += c * np.exp(-a * r) if self.n == 1: return v return r ** (self.n - 1) * v diff --git a/src/sisl/physics/__init__.py b/src/sisl/physics/__init__.py index ee01c67692..4af072d181 100644 --- a/src/sisl/physics/__init__.py +++ b/src/sisl/physics/__init__.py @@ -87,6 +87,7 @@ """ from . import electron, phonon + # Patch BrillouinZone objects and import apply classes from ._brillouinzone_apply import * from ._feature import * diff --git a/src/sisl/physics/_brillouinzone_apply.py b/src/sisl/physics/_brillouinzone_apply.py index 3035052460..162804f596 100644 --- a/src/sisl/physics/_brillouinzone_apply.py +++ b/src/sisl/physics/_brillouinzone_apply.py @@ -14,6 +14,7 @@ try: import xarray + _has_xarray = True except ImportError: _has_xarray = False @@ -38,7 +39,6 @@ __all__ += ["MonkhorstPackApply", "MonkhorstPackParentApply"] - def _asoplist(arg): if isinstance(arg, (tuple, list)) and not isinstance(arg, oplist): return oplist(arg) @@ -46,7 +46,7 @@ def _asoplist(arg): def _correct_str(orig, insert): - """ Correct string with `insert` """ + """Correct string with `insert`""" if len(insert) == 0: return orig i = orig.index("{") + 1 @@ -88,18 +88,18 @@ class BrillouinZoneApply(AbstractDispatch): @set_module("sisl.physics") class BrillouinZoneParentApply(BrillouinZoneApply): - - def __str__(self, message=''): + def __str__(self, message=""): return _correct_str(super().__str__(), message) def _parse_kwargs(self, wrap, eta=None, eta_key=""): - """ Parse kwargs """ + """Parse kwargs""" bz = self._obj parent = bz.parent if wrap is None: # we always return a wrap def wrap(v, parent=None, k=None, weight=None): return v + else: wrap = allow_kwargs("parent", "k", "weight")(wrap) eta = progressbar(len(bz), f"{bz.__class__.__name__}.{eta_key}", "k", eta) @@ -119,19 +119,27 @@ def __str__(self, message="iter"): return super().__str__(message) def dispatch(self, method, eta_key="iter"): - """ Dispatch the method by iterating values """ + """Dispatch the method by iterating values""" pool = _pool_procs(self._attrs.get("pool", None)) if pool is None: + @wraps(method) def func(*args, wrap=None, eta=None, **kwargs): bz, parent, wrap, eta = self._parse_kwargs(wrap, eta, eta_key=eta_key) k = bz.k w = bz.weight for i in range(len(k)): - yield wrap(method(*args, k=k[i], **kwargs), parent=parent, k=k[i], weight=w[i]) + yield wrap( + method(*args, k=k[i], **kwargs), + parent=parent, + k=k[i], + weight=w[i], + ) eta.update() eta.close() + else: + @wraps(method) def func(*args, wrap=None, eta=None, **kwargs): pool.restart() @@ -140,7 +148,9 @@ def func(*args, wrap=None, eta=None, **kwargs): w = bz.weight def func(k, w): - return wrap(method(*args, k=k, **kwargs), parent=parent, k=k, weight=w) + return wrap( + method(*args, k=k, **kwargs), parent=parent, k=k, weight=w + ) yield from pool.imap(func, k, w) # TODO notify users that this may be bad when used with zip @@ -158,7 +168,7 @@ def __str__(self, message="sum over k"): return super().__str__(message) def dispatch(self, method): - """ Dispatch the method by summing """ + """Dispatch the method by summing""" iter_func = super().dispatch(method, eta_key="sum") @wraps(method) @@ -176,7 +186,7 @@ def __str__(self, message="None"): return super().__str__(message) def dispatch(self, method): - """ Dispatch the method by doing nothing (mostly useful if wrapped) """ + """Dispatch the method by doing nothing (mostly useful if wrapped)""" iter_func = super().dispatch(method, eta_key="none") @wraps(method) @@ -194,16 +204,20 @@ def __str__(self, message="list"): return super().__str__(message) def dispatch(self, method): - """ Dispatch the method by returning list of values """ + """Dispatch the method by returning list of values""" iter_func = super().dispatch(method, eta_key="list") if self._attrs.get("zip", self._attrs.get("unzip", False)): + @wraps(method) def func(*args, **kwargs): return zip(*(v for v in iter_func(*args, **kwargs))) + else: + @wraps(method) def func(*args, **kwargs): return [v for v in iter_func(*args, **kwargs)] + return func @@ -213,16 +227,20 @@ def __str__(self, message="oplist"): return super().__str__(message) def dispatch(self, method): - """ Dispatch the method by returning oplist of values """ + """Dispatch the method by returning oplist of values""" iter_func = super().dispatch(method, eta_key="oplist") if self._attrs.get("zip", self._attrs.get("unzip", False)): + @wraps(method) def func(*args, **kwargs): return oplist(zip(*(v for v in iter_func(*args, **kwargs)))) + else: + @wraps(method) def func(*args, **kwargs): return oplist(v for v in iter_func(*args, **kwargs)) + return func @@ -232,7 +250,7 @@ def __str__(self, message="ndarray"): return super().__str__(message) def dispatch(self, method, eta_key="ndarray"): - """ Dispatch the method by one array """ + """Dispatch the method by one array""" pool = _pool_procs(self._attrs.get("pool", None)) unzip = self._attrs.get("zip", self._attrs.get("unzip", False)) @@ -242,6 +260,7 @@ def _create_v(nk, v): return out if pool is None: + @wraps(method) def func(*args, wrap=None, eta=None, **kwargs): bz, parent, wrap, eta = self._parse_kwargs(wrap, eta, eta_key=eta_key) @@ -250,13 +269,20 @@ def func(*args, wrap=None, eta=None, **kwargs): w = bz.weight # Get first values - v = wrap(method(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0]) + v = wrap( + method(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0] + ) eta.update() if unzip: a = tuple(_create_v(nk, vi) for vi in v) for i in range(1, len(k)): - v = wrap(method(*args, k=k[i], **kwargs), parent=parent, k=k[i], weight=w[i]) + v = wrap( + method(*args, k=k[i], **kwargs), + parent=parent, + k=k[i], + weight=w[i], + ) for ai, vi in zip(a, v): ai[i] = vi eta.update() @@ -264,12 +290,19 @@ def func(*args, wrap=None, eta=None, **kwargs): a = _create_v(nk, v) del v for i in range(1, len(k)): - a[i] = wrap(method(*args, k=k[i], **kwargs), parent=parent, k=k[i], weight=w[i]) + a[i] = wrap( + method(*args, k=k[i], **kwargs), + parent=parent, + k=k[i], + weight=w[i], + ) eta.update() eta.close() return a + else: + @wraps(method) def func(*args, wrap=None, **kwargs): pool.restart() @@ -279,7 +312,9 @@ def func(*args, wrap=None, **kwargs): w = bz.weight def func(k, w): - return wrap(method(*args, k=k, **kwargs), parent=parent, k=k, weight=w) + return wrap( + method(*args, k=k, **kwargs), parent=parent, k=k, weight=w + ) it = pool.imap(func, k, w) v = next(it) @@ -293,7 +328,7 @@ def func(k, w): else: a = _create_v(nk, v) for i, v in enumerate(it): - a[i+1] = v + a[i + 1] = v del v pool.terminate() return a @@ -307,23 +342,46 @@ def __str__(self, message="average"): return super().__str__(message) def dispatch(self, method): - """ Dispatch the method by averaging """ + """Dispatch the method by averaging""" pool = _pool_procs(self._attrs.get("pool", None)) if pool is None: + @wraps(method) def func(*args, wrap=None, eta=None, **kwargs): bz, parent, wrap, eta = self._parse_kwargs(wrap, eta, eta_key="average") # Do actual average k = bz.k w = bz.weight - v = _asoplist(wrap(method(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0])) * w[0] + v = ( + _asoplist( + wrap( + method(*args, k=k[0], **kwargs), + parent=parent, + k=k[0], + weight=w[0], + ) + ) + * w[0] + ) eta.update() for i in range(1, len(k)): - v += _asoplist(wrap(method(*args, k=k[i], **kwargs), parent=parent, k=k[i], weight=w[i])) * w[i] + v += ( + _asoplist( + wrap( + method(*args, k=k[i], **kwargs), + parent=parent, + k=k[i], + weight=w[i], + ) + ) + * w[i] + ) eta.update() eta.close() return v + else: + @wraps(method) def func(*args, wrap=None, **kwargs): pool.restart() @@ -332,7 +390,10 @@ def func(*args, wrap=None, **kwargs): w = bz.weight def func(k, w): - return wrap(method(*args, k=k, **kwargs), parent=parent, k=k, weight=w) * w + return ( + wrap(method(*args, k=k, **kwargs), parent=parent, k=k, weight=w) + * w + ) iter_func = pool.uimap(func, k, w) avg = reduce(op.add, iter_func, _asoplist(next(iter_func))) @@ -348,16 +409,16 @@ def __str__(self, message="xarray"): return super().__str__(message) def dispatch(self, method): - """ Dispatch the method by returning a DataArray or data-set """ + """Dispatch the method by returning a DataArray or data-set""" def _fix_coords_dims(nk, array, coords, dims, prefix="v"): if coords is None and dims is None: # we need to manually create them - coords = [('k', _a.arangei(nk))] + coords = [("k", _a.arangei(nk))] for i, v in enumerate(array.shape[1:]): coords.append((f"{prefix}{i+1}", _a.arangei(v))) elif coords is None: - coords = [('k', _a.arangei(nk))] + coords = [("k", _a.arangei(nk))] for i, v in enumerate(array.shape[1:]): coords.append((dims[i], _a.arangei(v))) # everything is in coords, no need to pass dims @@ -374,7 +435,7 @@ def _fix_coords_dims(nk, array, coords, dims, prefix="v"): if isinstance(coords, str): coords = [coords] coords = list(coords) - coords.insert(0, ('k', _a.arangei(nk))) + coords.insert(0, ("k", _a.arangei(nk))) for i in range(1, len(coords)): if isinstance(coords[i], str): coords[i] = (coords[i], _a.arangei(array.shape[i])) @@ -397,20 +458,22 @@ def func(*args, coords=(), dims=(), name=method.__name__, **kwargs): array = array_func(*args, **kwargs) def _create_DA(array, coords, dims, name): - coords, dims = _fix_coords_dims(len(bz), array, coords, dims, - prefix=f"{name}.v") + coords, dims = _fix_coords_dims( + len(bz), array, coords, dims, prefix=f"{name}.v" + ) return xarray.DataArray(array, coords=coords, dims=dims, name=name) if isinstance(name, str): name = [f"{name}{i}" for i in range(len(array))] - data = {nam: _create_DA(arr, coord, dim, nam) - for arr, coord, dim, nam - in zip_longest(array, coords, dims, name) + data = { + nam: _create_DA(arr, coord, dim, nam) + for arr, coord, dim, nam in zip_longest(array, coords, dims, name) } - attrs = {'bz': bz, 'parent': bz.parent} + attrs = {"bz": bz, "parent": bz.parent} return xarray.Dataset(data, attrs=attrs) + else: array_func = super().dispatch(method, eta_key="dataarray") @@ -421,11 +484,14 @@ def func(*args, coords=None, dims=None, name=method.__name__, **kwargs): # retrieve ALL data array = array_func(*args, **kwargs) coords, dims = _fix_coords_dims(len(bz), array, coords, dims) - attrs = {'bz': bz, 'parent': bz.parent} - return xarray.DataArray(array, coords=coords, dims=dims, name=name, attrs=attrs) + attrs = {"bz": bz, "parent": bz.parent} + return xarray.DataArray( + array, coords=coords, dims=dims, name=name, attrs=attrs + ) return func + # Register dispatched functions apply_dispatch = BrillouinZone.apply apply_dispatch.register("iter", IteratorApply, default=True) @@ -452,18 +518,18 @@ class MonkhorstPackApply(BrillouinZoneApply): @set_module("sisl.physics") class MonkhorstPackParentApply(MonkhorstPackApply): - - def __str__(self, message=''): + def __str__(self, message=""): return _correct_str(super().__str__(), message) def _parse_kwargs(self, wrap, eta=None, eta_key=""): - """ Parse kwargs """ + """Parse kwargs""" bz = self._obj parent = bz.parent if wrap is None: # we always return a wrap def wrap(v, parent=None, k=None, weight=None): return v + else: wrap = allow_kwargs("parent", "k", "weight")(wrap) eta = progressbar(len(bz), f"{bz.__class__.__name__}.{eta_key}", "k", eta) @@ -479,7 +545,7 @@ def __getattr__(self, key): @set_module("sisl.physics") class GridApply(MonkhorstPackParentApply): - """ Calculate on a Grid + """Calculate on a Grid The calculation of values on a grid requires some careful thought before running the calculation as the returned grid may be somewhat difficult @@ -506,16 +572,16 @@ class GridApply(MonkhorstPackParentApply): >>> obj = MonkhorstPack(Hamiltonian, [10, 1, 10]) >>> grid = obj.asgrid().eigh(data_axis=1) """ + def __str__(self, message="grid"): return super().__str__(message) def dispatch(self, method, eta_key="grid"): - """ Dispatch the method by putting values on the grid """ + """Dispatch the method by putting values on the grid""" pool = _pool_procs(self._attrs.get("pool", None)) @wraps(method) def func(*args, wrap=None, eta=None, **kwargs): - data_axis = kwargs.pop("data_axis", None) grid_unit = kwargs.pop("grid_unit", "b") @@ -527,7 +593,9 @@ def func(*args, wrap=None, eta=None, **kwargs): # define the Grid size, etc. diag = mp._diag.copy() if not np.all(mp._displ == 0): - raise SislError(f"{mp.__class__.__name__ } requires the displacement to be 0 for all k-points.") + raise SislError( + f"{mp.__class__.__name__ } requires the displacement to be 0 for all k-points." + ) displ = mp._displ.copy() size = mp._size.copy() steps = size / diag @@ -547,6 +615,7 @@ def func(*args, wrap=None, eta=None, **kwargs): _in_primitive = mp.in_primitive _rint = np.rint _int32 = np.int32 + def k2idx(k): # In case TRS is applied two indices may be returned return _rint((_in_primitive(k) - offset) / steps).astype(_int32) @@ -555,49 +624,70 @@ def k2idx(k): # with i in [0, 1, 2] # Create cell from the reciprocal cell. - if grid_unit == 'b': + if grid_unit == "b": cell = np.diag(mp._size) else: - cell = parent.lattice.rcell * mp._size.reshape(1, -1) / units("Ang", grid_unit) + cell = ( + parent.lattice.rcell + * mp._size.reshape(1, -1) + / units("Ang", grid_unit) + ) # Find the grid origin origin = -(cell * 0.5).sum(0) # Calculate first k-point (to get size and dtype) - v = wrap(method(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0]) + v = wrap( + method(*args, k=k[0], **kwargs), parent=parent, k=k[0], weight=w[0] + ) if data_axis is None: if v.size != 1: - raise SislError(f"{self.__class__.__name__} {func.__name__} requires one value per-kpoint because of the 3D grid values") + raise SislError( + f"{self.__class__.__name__} {func.__name__} requires one value per-kpoint because of the 3D grid values" + ) else: - # Check the weights - weights = mp.grid(diag[data_axis], displ[data_axis], size[data_axis], - centered=mp._centered, trs=trs_axis == data_axis)[1] + weights = mp.grid( + diag[data_axis], + displ[data_axis], + size[data_axis], + centered=mp._centered, + trs=trs_axis == data_axis, + )[1] # Correct the Grid size diag[data_axis] = len(v) # Create the orthogonal cell direction to ensure it is orthogonal # Since array axis is cyclic for negative numbers, we simply do this - cell[data_axis, :] = cross(cell[data_axis-1, :], cell[data_axis-2, :]) + cell[data_axis, :] = cross( + cell[data_axis - 1, :], cell[data_axis - 2, :] + ) # Check whether we should rotate it if cart2spher(cell[data_axis, :])[2] > pi / 4: cell[data_axis, :] *= -1 # Correct cell for the grid if trs_axis >= 0: - origin[trs_axis] = 0. + origin[trs_axis] = 0.0 # Correct offset since we only have the positive halve if mp._diag[trs_axis] % 2 == 0 and not mp._centered: offset[trs_axis] = steps[trs_axis] / 2 else: - offset[trs_axis] = 0. + offset[trs_axis] = 0.0 # Find number of points if trs_axis != data_axis: - diag[trs_axis] = len(mp.grid(diag[trs_axis], displ[trs_axis], size[trs_axis], - centered=mp._centered, trs=True)[1]) + diag[trs_axis] = len( + mp.grid( + diag[trs_axis], + displ[trs_axis], + size[trs_axis], + centered=mp._centered, + trs=True, + )[1] + ) # Create the grid in the reciprocal cell lattice = Lattice(cell, origin=origin) @@ -616,16 +706,27 @@ def k2idx(k): eta.update() if data_axis is None: for i in range(1, len(k)): - grid[k2idx(k[i])] = wrap(method(*args, k=k[i], **kwargs), - parent=parent, k=k[i], weight=w[i]) + grid[k2idx(k[i])] = wrap( + method(*args, k=k[i], **kwargs), + parent=parent, + k=k[i], + weight=w[i], + ) eta.update() else: for i in range(1, len(k)): idx = k2idx(k[i]).tolist() weight = weights[idx[data_axis]] idx[data_axis] = slice(None) - grid[tuple(idx)] = wrap(method(*args, k=k[i], **kwargs), - parent=parent, k=k[i], weight=w[i]) * weight + grid[tuple(idx)] = ( + wrap( + method(*args, k=k[i], **kwargs), + parent=parent, + k=k[i], + weight=w[i], + ) + * weight + ) eta.update() eta.close() return grid diff --git a/src/sisl/physics/_feature.py b/src/sisl/physics/_feature.py index 3971e5794d..b10ae9d0f0 100644 --- a/src/sisl/physics/_feature.py +++ b/src/sisl/physics/_feature.py @@ -8,8 +8,8 @@ __all__ = ["yield_manifolds"] -def yield_manifolds(values, atol: float=0.1, axis: int=-1) -> Iterator[List]: - r""" Yields indices for manifolds along the axis `axis` +def yield_manifolds(values, atol: float = 0.1, axis: int = -1) -> Iterator[List]: + r"""Yields indices for manifolds along the axis `axis` A manifold is found under the criteria that all neighbouring values along `axis` are separated by at least `atol`. @@ -43,7 +43,7 @@ def yield_manifolds(values, atol: float=0.1, axis: int=-1) -> Iterator[List]: # Now calculate the manifold for each of the different directions manifold = [0] for i in range(1, len(v_min)): - if np.all(v_max[i-1] < v_min[i] - atol): + if np.all(v_max[i - 1] < v_min[i] - atol): # we are starting a new manifold yield manifold manifold = [i] diff --git a/src/sisl/physics/bloch.py b/src/sisl/physics/bloch.py index 583e1e5d3a..976872e535 100644 --- a/src/sisl/physics/bloch.py +++ b/src/sisl/physics/bloch.py @@ -17,7 +17,7 @@ from ._bloch import bloch_unfold -__all__ = ['Bloch'] +__all__ = ["Bloch"] @set_module("sisl.physics") @@ -64,30 +64,34 @@ class Bloch: """ def __init__(self, *bloch): - """ Create `Bloch` object """ + """Create `Bloch` object""" self._bloch = _a.arrayi(bloch).ravel() - self._bloch = np.where(self._bloch < 1, 1, self._bloch).astype(np.int32, copy=False) + self._bloch = np.where(self._bloch < 1, 1, self._bloch).astype( + np.int32, copy=False + ) if len(self._bloch) != 3: - raise ValueError(self.__class__.__name__ + ' requires 3 input values') + raise ValueError(self.__class__.__name__ + " requires 3 input values") if np.any(self._bloch < 1): - raise ValueError(self.__class__.__name__ + ' requires all unfoldings to be larger than 0') + raise ValueError( + self.__class__.__name__ + " requires all unfoldings to be larger than 0" + ) def __len__(self): - """ Return unfolded size """ + """Return unfolded size""" return np.prod(self.bloch) def __str__(self): - """ Representation of the Bloch model """ + """Representation of the Bloch model""" B = self._bloch return f"{self.__class__.__name__}{{{B[0]}, {B[1]}, {B[2]}}}" @property def bloch(self): - """ Number of Bloch expansions along each lattice vector """ + """Number of Bloch expansions along each lattice vector""" return self._bloch def unfold_points(self, k): - r""" Return a list of k-points to be evaluated for this objects unfolding + r"""Return a list of k-points to be evaluated for this objects unfolding The k-point `k` is with respect to the unfolded geometry. The return list of `k` points are the k-points required to be sampled in the @@ -116,7 +120,7 @@ def unfold_points(self, k): return unfold.reshape(-1, 3) def __call__(self, func, k, *args, **kwargs): - """ Return a functions return values as the Bloch unfolded equivalent according to this object + """Return a functions return values as the Bloch unfolded equivalent according to this object Calling the `Bloch` object is a shorthand for the manual use of the `Bloch.unfold_points` and `Bloch.unfold` methods. @@ -159,7 +163,7 @@ def __call__(self, func, k, *args, **kwargs): return bloch_unfold(self._bloch, K_unfold, M) def unfold(self, M, k_unfold): - r""" Unfold the matrix list of matrices `M` into a corresponding k-point (unfolding k-points are `k_unfold`) + r"""Unfold the matrix list of matrices `M` into a corresponding k-point (unfolding k-points are `k_unfold`) Parameters ---------- diff --git a/src/sisl/physics/brillouinzone.py b/src/sisl/physics/brillouinzone.py index 6e97ac105e..c1c56fb6da 100644 --- a/src/sisl/physics/brillouinzone.py +++ b/src/sisl/physics/brillouinzone.py @@ -159,7 +159,7 @@ class BrillouinZoneDispatcher(ClassDispatcher): - r""" Loop over all k-points by applying `parent` methods for all k. + r"""Loop over all k-points by applying `parent` methods for all k. This allows potential for running and collecting various computationally heavy methods from a single point on all k-points. @@ -192,7 +192,7 @@ class BrillouinZoneDispatcher(ClassDispatcher): @set_module("sisl.physics") def linspace_bz(bz, stop=None, jumps=None, jump_dk=0.05): - r""" Convert points from a BZ object into a linear spacing of maximum value `stop` + r"""Convert points from a BZ object into a linear spacing of maximum value `stop` Parameters ---------- @@ -217,14 +217,14 @@ def linspace_bz(bz, stop=None, jumps=None, jump_dk=0.05): # calculate vectors between each neighbouring points dcart = np.diff(cart, axis=0, prepend=cart[0].reshape(1, -1)) # calculate distances - dist = (dcart ** 2).sum(1) ** 0.5 + dist = (dcart**2).sum(1) ** 0.5 if jumps is not None: # calculate the total distance total_dist = dist.sum() # Zero out the jumps - dist[jumps] = 0. + dist[jumps] = 0.0 total_dist = dist.sum() # correct jumps dist[jumps] = total_dist * np.asarray(jump_dk) @@ -240,7 +240,7 @@ def linspace_bz(bz, stop=None, jumps=None, jump_dk=0.05): @set_module("sisl.physics") class BrillouinZone: - """ A class to construct Brillouin zone related quantities + """A class to construct Brillouin zone related quantities It takes any object (which has access to cell-vectors) as an argument and can then return the k-points in non-reduced units from reduced units. @@ -279,16 +279,18 @@ def __init__(self, parent, k=None, weight=None): self._k = _a.arrayd(k).reshape(-1, 3) self._w = _a.emptyd(len(k)) if weight is None: - weight = 1. / len(self._k) + weight = 1.0 / len(self._k) self._w[:] = weight - apply = BrillouinZoneDispatcher("apply", - # Do not allow class dispatching - type_dispatcher=None, - obj_getattr=lambda obj, key: getattr(obj.parent, key)) + apply = BrillouinZoneDispatcher( + "apply", + # Do not allow class dispatching + type_dispatcher=None, + obj_getattr=lambda obj, key: getattr(obj.parent, key), + ) def set_parent(self, parent): - """ Update the parent associated to this object + """Update the parent associated to this object Parameters ---------- @@ -304,7 +306,7 @@ def set_parent(self, parent): self.parent = Lattice(parent) def __str__(self): - """ String representation of the BrillouinZone """ + """String representation of the BrillouinZone""" parent = self.parent if isinstance(parent, Lattice): parent = str(parent).replace("\n", "\n ") @@ -313,25 +315,25 @@ def __str__(self): return f"{self.__class__.__name__}{{nk: {len(self)},\n {parent}\n}}" def __getstate__(self): - """ Return dictionary with the current state """ + """Return dictionary with the current state""" return { - 'parent_class': self.parent.__class__, - 'parent': self.parent.__getstate__(), - 'k': self._k.copy(), - 'weight': self._w.copy() + "parent_class": self.parent.__class__, + "parent": self.parent.__getstate__(), + "k": self._k.copy(), + "weight": self._w.copy(), } def __setstate__(self, state): - """ Reset state of the object """ - self._k = state['k'] - self._w = state['weight'] - parent = state['parent_class'].__new__(state['parent_class']) - parent.__setstate__(state['parent']) + """Reset state of the object""" + self._k = state["k"] + self._w = state["weight"] + parent = state["parent_class"].__new__(state["parent_class"]) + parent.__setstate__(state["parent"]) self.set_parent(parent) @staticmethod - def merge(bzs, weight_scale=1., parent=None): - """ Merge several BrillouinZone objects into one + def merge(bzs, weight_scale=1.0, parent=None): + """Merge several BrillouinZone objects into one The merging strategy only stores the new list of k-points and weights. Information retained in the merged objects will not be stored. @@ -359,22 +361,26 @@ def merge(bzs, weight_scale=1., parent=None): # check for lengths (scales cannot be longer!) if len(bzs) < len(weight_scale): - raise ValueError("BrillouinZone.merge requires length of weight_scale to be smaller or equal to " - "the objects.") + raise ValueError( + "BrillouinZone.merge requires length of weight_scale to be smaller or equal to " + "the objects." + ) if parent is None: parent = bzs[0].parent k = [] w = [] - for bz, scale in itertools.zip_longest(bzs, weight_scale, fillvalue=weight_scale[-1]): + for bz, scale in itertools.zip_longest( + bzs, weight_scale, fillvalue=weight_scale[-1] + ): k.append(bz.k) w.append(bz.weight * scale) return BrillouinZone(parent, np.concatenate(k), np.concatenate(w)) def volume(self, ret_dim=False, periodic=None): - """ Calculate the volume of the full Brillouin zone of the parent + """Calculate the volume of the full Brillouin zone of the parent This will return the volume depending on the dimensions of the system. Here the dimensions of the system is determined by how many dimensions @@ -403,7 +409,7 @@ def volume(self, ret_dim=False, periodic=None): periodic = (self.parent.nsc > 1).nonzero()[0] dim = len(periodic) - vol = 0. + vol = 0.0 if dim == 3: vol = self.parent.volume elif dim == 2: @@ -417,7 +423,7 @@ def volume(self, ret_dim=False, periodic=None): @staticmethod def parametrize(parent, func, N, *args, **kwargs): - """ Generate a new `BrillouinZone` object with k-points parameterized via the function `func` in `N` separations + """Generate a new `BrillouinZone` object with k-points parameterized via the function `func` in `N` separations Generator of a parameterized Brillouin zone object that contains a parameterized k-point list. @@ -476,7 +482,7 @@ def parametrize(parent, func, N, *args, **kwargs): @staticmethod def param_circle(parent, N_or_dk, kR, normal, origin, loop=False): - r""" Create a parameterized k-point list where the k-points are generated on a circle around an origin + r"""Create a parameterized k-point list where the k-points are generated on a circle around an origin The generated circle is a perfect circle in the reciprocal space (Cartesian coordinates). To generate a perfect circle in units of the reciprocal lattice vectors one can @@ -523,10 +529,12 @@ def param_circle(parent, N_or_dk, kR, normal, origin, loop=False): N = N_or_dk else: # Calculate the required number of points - N = int(kR ** 2 * pi / N_or_dk + 0.5) + N = int(kR**2 * pi / N_or_dk + 0.5) if N < 2: N = 2 - info('BrillouinZone.param_circle increased the number of circle points to 2.') + info( + "BrillouinZone.param_circle increased the number of circle points to 2." + ) # Conversion object bz = BrillouinZone(parent) @@ -538,13 +546,13 @@ def param_circle(parent, N_or_dk, kR, normal, origin, loop=False): # Generate a preset list of k-points on the unit-circle if loop: - radians = _a.aranged(N) / (N-1) * 2 * np.pi + radians = _a.aranged(N) / (N - 1) * 2 * np.pi else: radians = _a.aranged(N) / N * 2 * np.pi k = _a.emptyd([N, 3]) k[:, 0] = np.cos(radians) k[:, 1] = np.sin(radians) - k[:, 2] = 0. + k[:, 2] = 0.0 # Now generate the rotation _, theta, phi = cart2spher(k_n) @@ -553,7 +561,7 @@ def param_circle(parent, N_or_dk, kR, normal, origin, loop=False): pv /= fnorm(pv) q = Quaternion(phi, pv, rad=True) * Quaternion(theta, [0, 0, 1], rad=True) else: - q = Quaternion(0., [0, 0, k_n[2] / abs(k_n[2])], rad=True) + q = Quaternion(0.0, [0, 0, k_n[2] / abs(k_n[2])], rad=True) # Calculate k-points k = q.rotate(k) @@ -561,13 +569,13 @@ def param_circle(parent, N_or_dk, kR, normal, origin, loop=False): k = bz.toreduced(k + k_o) # The sum of weights is equal to the BZ area - W = np.pi * kR ** 2 + W = np.pi * kR**2 w = np.repeat([W / N], N) return BrillouinZone(parent, k, w) def copy(self, parent=None): - """ Create a copy of this object, optionally changing the parent + """Create a copy of this object, optionally changing the parent Parameters ---------- @@ -583,12 +591,12 @@ def copy(self, parent=None): @property def k(self): - """ A list of all k-points (if available) """ + """A list of all k-points (if available)""" return self._k @property def weight(self): - """ Weight of the k-points in the `BrillouinZone` object """ + """Weight of the k-points in the `BrillouinZone` object""" return self._w @property @@ -600,7 +608,7 @@ def rcell(self): return self.parent.rcell def tocartesian(self, k): - """ Transfer a k-point in reduced coordinates to the Cartesian coordinates + """Transfer a k-point in reduced coordinates to the Cartesian coordinates Parameters ---------- @@ -615,7 +623,7 @@ def tocartesian(self, k): return dot(k, self.rcell) def toreduced(self, k): - """ Transfer a k-point in Cartesian coordinates to the reduced coordinates + """Transfer a k-point in Cartesian coordinates to the reduced coordinates Parameters ---------- @@ -631,7 +639,7 @@ def toreduced(self, k): @staticmethod def in_primitive(k): - """ Move the k-point into the primitive point(s) ]-0.5 ; 0.5] + """Move the k-point into the primitive point(s) ]-0.5 ; 0.5] Parameters ---------- @@ -643,7 +651,7 @@ def in_primitive(k): numpy.ndarray all k-points moved into the primitive cell """ - k = _a.arrayd(k) % 1. + k = _a.arrayd(k) % 1.0 # Ensure that we are in the interval ]-0.5; 0.5] k[k > 0.5] -= 1 @@ -651,7 +659,7 @@ def in_primitive(k): return k def iter(self, ret_weight=False): - """ An iterator for the k-points and (possibly) the weights + """An iterator for the k-points and (possibly) the weights Parameters ---------- @@ -662,7 +670,7 @@ def iter(self, ret_weight=False): ------ kpt : k-point weight : weight of k-point, only if `ret_weight` is true. - """ + """ if ret_weight: for i in range(len(self)): yield self.k[i], self.weight[i] @@ -675,22 +683,23 @@ def __len__(self): return len(self._k) def write(self, sile, *args, **kwargs): - """ Writes k-points to a `~sisl.io.tableSile`. + """Writes k-points to a `~sisl.io.tableSile`. This allows one to pass a `tableSile` or a file-name. """ from sisl.io import tableSile + kw = np.concatenate((self.k, self.weight.reshape(-1, 1)), axis=1) if isinstance(sile, tableSile): sile.write_data(kw.T, *args, **kwargs) else: - with tableSile(sile, 'w') as fh: + with tableSile(sile, "w") as fh: fh.write_data(kw.T, *args, **kwargs) @set_module("sisl.physics") class MonkhorstPack(BrillouinZone): - r""" Create a Monkhorst-Pack grid for the Brillouin zone + r"""Create a Monkhorst-Pack grid for the Brillouin zone Parameters ---------- @@ -722,7 +731,9 @@ class MonkhorstPack(BrillouinZone): >>> MonkhorstPack(lattice, [10, 5, 5], trs=False) # 10 x 5 x 5 (without TRS) """ - def __init__(self, parent, nkpt, displacement=None, size=None, centered=True, trs=True): + def __init__( + self, parent, nkpt, displacement=None, size=None, centered=True, trs=True + ): super().__init__(parent) if isinstance(nkpt, Integral): @@ -732,7 +743,9 @@ def __init__(self, parent, nkpt, displacement=None, size=None, centered=True, tr # Now we have a matrix of k-points if np.any(nkpt - np.diag(np.diag(nkpt)) != 0): - raise NotImplementedError(f"{self.__class__.__name__} with off-diagonal components is not implemented yet") + raise NotImplementedError( + f"{self.__class__.__name__} with off-diagonal components is not implemented yet" + ) if displacement is None: displacement = np.zeros(3, np.float64) @@ -753,15 +766,17 @@ def __init__(self, parent, nkpt, displacement=None, size=None, centered=True, tr # Retrieve the diagonal number of values Dn = np.diag(nkpt).astype(np.int32) if np.any(Dn) == 0: - raise ValueError(f'{self.__class__.__name__} *must* be initialized with ' - 'diagonal elements different from 0.') + raise ValueError( + f"{self.__class__.__name__} *must* be initialized with " + "diagonal elements different from 0." + ) i_trs = -1 if trs: # Figure out which direction to TRS nmax = 0 for i in [0, 1, 2]: - if displacement[i] in [0., 0.5] and Dn[i] > nmax: + if displacement[i] in [0.0, 0.5] and Dn[i] > nmax: nmax = Dn[i] i_trs = i if nmax == 1: @@ -772,7 +787,10 @@ def __init__(self, parent, nkpt, displacement=None, size=None, centered=True, tr i_trs = np.argmax(Dn) # Calculate k-points and weights along all directions - kw = [self.grid(Dn[i], displacement[i], size[i], centered, i == i_trs) for i in (0, 1, 2)] + kw = [ + self.grid(Dn[i], displacement[i], size[i], centered, i == i_trs) + for i in (0, 1, 2) + ] # Now figure out if we have a 0 point along the TRS direction if trs: @@ -789,16 +807,20 @@ def __init__(self, parent, nkpt, displacement=None, size=None, centered=True, tr # Note for a 100 x 100 k-point sampling this will produce # a 100 ^ 4 matrix ~ 93 MB # For larger k-point samplings this is probably not so good (300x300 -> 7.5 GB) - k_dup = k_dup.reshape(k1.size, k2.size, 1, 1, 2) + k_dup.reshape(1, 1, k1.size, k2.size, 2) - k_dup = ((k_dup[..., 0] ** 2 + k_dup[..., 1] ** 2) ** 0.5 < 1e-10).nonzero() + k_dup = k_dup.reshape(k1.size, k2.size, 1, 1, 2) + k_dup.reshape( + 1, 1, k1.size, k2.size, 2 + ) + k_dup = ( + (k_dup[..., 0] ** 2 + k_dup[..., 1] ** 2) ** 0.5 < 1e-10 + ).nonzero() # At this point we have found all duplicate points, to only take one # half of the points we only take the lower half # Also, the Gamma point is *always* zero, so we shouldn't do <=! # Now check the case where one of the directions is (only) the Gamma-point - if kw[ik1][0].size == 1 and kw[ik1][0][0] == 0.: + if kw[ik1][0].size == 1 and kw[ik1][0][0] == 0.0: # We keep all indices for the ik1 direction (since it is the Gamma-point! rel = (k_dup[1] > k_dup[3]).nonzero()[0] - elif kw[ik2][0].size == 1 and kw[ik2][0][0] == 0.: + elif kw[ik2][0].size == 1 and kw[ik2][0][0] == 0.0: # We keep all indices for the ik2 direction (since it is the Gamma-point! rel = (k_dup[0] > k_dup[2]).nonzero()[0] else: @@ -842,50 +864,58 @@ def __init__(self, parent, nkpt, displacement=None, size=None, centered=True, tr # Store information regarding size and diagonal elements # This information is basically only necessary when # we want to replace special k-points - self._diag = Dn # vector - self._displ = displacement # vector - self._size = size # vector + self._diag = Dn # vector + self._displ = displacement # vector + self._size = size # vector self._centered = centered self._trs = i_trs @property def displacement(self): - """ Displacement for this Monkhorst-Pack grid """ + """Displacement for this Monkhorst-Pack grid""" return self._displ def __str__(self): - """ String representation of `MonkhorstPack` """ + """String representation of `MonkhorstPack`""" if isinstance(self.parent, Lattice): p = self.parent else: p = self.parent.lattice - return ('{cls}{{nk: {nk:d}, size: [{size[0]:.5f} {size[1]:.5f} {size[0]:.5f}], trs: {trs},' - '\n diagonal: [{diag[0]:d} {diag[1]:d} {diag[2]:d}], displacement: [{disp[0]:.5f} {disp[1]:.5f} {disp[2]:.5f}],' - '\n {lattice}\n}}').format(cls=self.__class__.__name__, nk=len(self), - size=self._size, trs={0: 'A', 1: 'B', 2: 'C'}.get(self._trs, 'no'), - diag=self._diag, disp=self._displ, lattice=str(p).replace('\n', '\n ')) + return ( + "{cls}{{nk: {nk:d}, size: [{size[0]:.5f} {size[1]:.5f} {size[0]:.5f}], trs: {trs}," + "\n diagonal: [{diag[0]:d} {diag[1]:d} {diag[2]:d}], displacement: [{disp[0]:.5f} {disp[1]:.5f} {disp[2]:.5f}]," + "\n {lattice}\n}}" + ).format( + cls=self.__class__.__name__, + nk=len(self), + size=self._size, + trs={0: "A", 1: "B", 2: "C"}.get(self._trs, "no"), + diag=self._diag, + disp=self._displ, + lattice=str(p).replace("\n", "\n "), + ) def __getstate__(self): - """ Return dictionary with the current state """ + """Return dictionary with the current state""" state = super().__getstate__() - state['diag'] = self._diag - state['displ'] = self._displ - state['size'] = self._size - state['centered'] = self._centered - state['trs'] = self._trs + state["diag"] = self._diag + state["displ"] = self._displ + state["size"] = self._size + state["centered"] = self._centered + state["trs"] = self._trs return state def __setstate__(self, state): - """ Reset state of the object """ + """Reset state of the object""" super().__setstate__(state) - self._diag = state['diag'] - self._displ = state['displ'] - self._size = state['size'] - self._centered = state['centered'] - self._trs = state['trs'] + self._diag = state["diag"] + self._displ = state["displ"] + self._size = state["size"] + self._centered = state["centered"] + self._trs = state["trs"] def copy(self, parent=None): - """ Create a copy of this object, optionally changing the parent + """Create a copy of this object, optionally changing the parent Parameters ---------- @@ -894,15 +924,17 @@ def copy(self, parent=None): """ if parent is None: parent = self.parent - bz = self.__class__(parent, self._diag, self._displ, self._size, self._centered, self._trs >= 0) + bz = self.__class__( + parent, self._diag, self._displ, self._size, self._centered, self._trs >= 0 + ) # this is required due to replace calls bz._k = self._k.copy() bz._w = self._w.copy() return bz @classmethod - def grid(cls, n, displ=0., size=1., centered=True, trs=False): - r""" Create a grid of `n` points with an offset of `displ` and sampling `size` around `displ` + def grid(cls, n, displ=0.0, size=1.0, centered=True, trs=False): + r"""Create a grid of `n` points with an offset of `displ` and sampling `size` around `displ` The :math:`k`-points are :math:`\Gamma` centered. @@ -927,16 +959,16 @@ def grid(cls, n, displ=0., size=1., centered=True, trs=False): weights for the k-points """ # First ensure that displ is in the Brillouin - displ = displ % 1. + displ = displ % 1.0 if displ > 0.5: - displ -= 1. + displ -= 1.0 if displ < -0.5: - displ += 1. + displ += 1.0 # Centered _only_ has effect IFF # displ == 0. and size == 1 # Otherwise we resort to other schemes - if displ != 0. or size != 1.: + if displ != 0.0 or size != 1.0: centered = False # size *per k-point* @@ -957,7 +989,7 @@ def grid(cls, n, displ=0., size=1., centered=True, trs=False): w = _a.fulld(n, dsize) # Check for TRS points - if trs and np.any(k < 0.): + if trs and np.any(k < 0.0): # Make all positive to remove the double conting terms k_pos = np.fabs(k) @@ -989,7 +1021,7 @@ def grid(cls, n, displ=0., size=1., centered=True, trs=False): return k, w def replace(self, k, mp, displacement=False, as_index=False, check_vol=True): - r""" Replace a k-point with a new set of k-points from a Monkhorst-Pack grid + r"""Replace a k-point with a new set of k-points from a Monkhorst-Pack grid This method tries to replace an area corresponding to `mp.size` around the k-point `k` such that the k-points are replaced. @@ -1063,8 +1095,10 @@ def replace(self, k, mp, displacement=False, as_index=False, check_vol=True): # k-point volumes. k_int = mp._size / k_vol if not np.allclose(np.rint(k_int), k_int): - raise SislError(f"{self.__class__.__name__}.reduce could not replace k-point, BZ " - "volume replaced is not equivalent to the inherent k-point volume.") + raise SislError( + f"{self.__class__.__name__}.reduce could not replace k-point, BZ " + "volume replaced is not equivalent to the inherent k-point volume." + ) # the size of the k-points that will be added s_size2 = self._size / 2 @@ -1081,18 +1115,24 @@ def replace(self, k, mp, displacement=False, as_index=False, check_vol=True): else: # find k-points in batches of 200 MB k = self.in_primitive(k).reshape(-1, 3) - idx = batched_indices(self.k, k, atol=dk, batch_size=200, - diff_func=self.in_primitive)[0] - if self._trs >= 0: # TRS along a given axis, we can search the mirrored values - idx2 = batched_indices(self.k, -k, atol=dk, batch_size=200, - diff_func=self.in_primitive)[0] + idx = batched_indices( + self.k, k, atol=dk, batch_size=200, diff_func=self.in_primitive + )[0] + if ( + self._trs >= 0 + ): # TRS along a given axis, we can search the mirrored values + idx2 = batched_indices( + self.k, -k, atol=dk, batch_size=200, diff_func=self.in_primitive + )[0] idx = np.concatenate((idx, idx2)) # we may find 2 indices for gamm-point in this case... not useful idx = np.unique(idx) if len(idx) == 0: - raise SislError(f"{self.__class__.__name__}.reduce found no k-points to replace. " - f"Searched with precision: {dk.ravel()}") + raise SislError( + f"{self.__class__.__name__}.reduce found no k-points to replace. " + f"Searched with precision: {dk.ravel()}" + ) # Idea of fast replacements is attributed @ahkole in #454, but the resulting code needed some # changes since that code was not stable againts *wrong input*, i.e. k=[0, 0, 0] @@ -1122,41 +1162,54 @@ def replace(self, k, mp, displacement=False, as_index=False, check_vol=True): replace_weight = mp.weight.sum() * displ_nk atol = min(total_weight, replace_weight) * 1e-4 if abs(total_weight - replace_weight) < atol: - weight_factor = 1. + weight_factor = 1.0 elif abs(total_weight - replace_weight * 2) < atol: - weight_factor = 2. + weight_factor = 2.0 if self._trs < 0: - info(f"{self.__class__.__name__}.reduce assumes that the replaced k-point has double weights.") + info( + f"{self.__class__.__name__}.reduce assumes that the replaced k-point has double weights." + ) else: - #print("k-point to replace: ", k.ravel()) - #print("delta-k: ", dk.ravel()) - #print("Found k-indices that will be replaced:") - #print(idx) - #print("k-points replaced:") - #print(self.k[idx, :]) - #print("weights replaced:") - #print(self.weight[idx]) - #print(self.weight.min(), self.weight.max()) - #print(mp.weight.min(), mp.weight.max()) - #print("Summed weights vs. replaced summed weights: ") - #print(total_weight, replace_weight) - #print(mp) - raise SislError(f"{self.__class__.__name__}.reduce found inconsistent replacement weights " - f"self={total_weight} vs. mp={replace_weight}. " - f"Replacement indices: {idx}.") + # print("k-point to replace: ", k.ravel()) + # print("delta-k: ", dk.ravel()) + # print("Found k-indices that will be replaced:") + # print(idx) + # print("k-points replaced:") + # print(self.k[idx, :]) + # print("weights replaced:") + # print(self.weight[idx]) + # print(self.weight.min(), self.weight.max()) + # print(mp.weight.min(), mp.weight.max()) + # print("Summed weights vs. replaced summed weights: ") + # print(total_weight, replace_weight) + # print(mp) + raise SislError( + f"{self.__class__.__name__}.reduce found inconsistent replacement weights " + f"self={total_weight} vs. mp={replace_weight}. " + f"Replacement indices: {idx}." + ) # delete and append new k-points and weights if displacement is None: self._k = np.concatenate((np.delete(self._k, idx, axis=0), mp._k), axis=0) else: - self._k = np.concatenate((np.delete(self._k, idx, axis=0), - self.in_primitive(mp.k + displacement.reshape(-1, 1, 3)).reshape(-1, 3)), axis=0) - self._w = np.concatenate((np.delete(self._w, idx), np.tile(mp._w * weight_factor, displ_nk))) + self._k = np.concatenate( + ( + np.delete(self._k, idx, axis=0), + self.in_primitive(mp.k + displacement.reshape(-1, 1, 3)).reshape( + -1, 3 + ), + ), + axis=0, + ) + self._w = np.concatenate( + (np.delete(self._w, idx), np.tile(mp._w * weight_factor, displ_nk)) + ) @set_module("sisl.physics") class BandStructure(BrillouinZone): - """ Create a path in the Brillouin zone for plotting band-structures etc. + """Create a path in the Brillouin zone for plotting band-structures etc. Parameters ---------- @@ -1197,10 +1250,14 @@ class BandStructure(BrillouinZone): >>> bs = BandStructure(lattice, [[0, 0, 0], [0, 0.5, 0], None, [0.5, 0, 0], [0.5, 0.5, 0]], 200) """ - @deprecate_argument("name", "names", "argument 'name' has been deprecated in favor of 'names', please update your code.", - "0.15.0") + @deprecate_argument( + "name", + "names", + "argument 'name' has been deprecated in favor of 'names', please update your code.", + "0.15.0", + ) def __init__(self, parent, *args, **kwargs): - #points, divisions, names=None): + # points, divisions, names=None): super().__init__(parent) points = kwargs.pop("points", None) @@ -1215,7 +1272,9 @@ def __init__(self, parent, *args, **kwargs): if len(args) > 0: divisions, *args = args else: - raise ValueError(f"{self.__class__.__name__} 'divisions' argument missing") + raise ValueError( + f"{self.__class__.__name__} 'divisions' argument missing" + ) names = kwargs.pop("names", None) if names is None: @@ -1223,14 +1282,17 @@ def __init__(self, parent, *args, **kwargs): names, *args = args if len(args) > 0: - raise ValueError(f"{self.__class__.__name__} unknown arguments after parsing 'points', 'divisions' and 'names': {args}") + raise ValueError( + f"{self.__class__.__name__} unknown arguments after parsing 'points', 'divisions' and 'names': {args}" + ) # Store empty split size self._jump_dk = np.asarray(kwargs.pop("jump_dk", 0.05)) if len(kwargs) > 0: - raise ValueError(f"{self.__class__.__name__} unknown keyword arguments after parsing [points, divisions, names, jump_dk]: {list(kwargs.keys())}") - + raise ValueError( + f"{self.__class__.__name__} unknown keyword arguments after parsing [points, divisions, names, jump_dk]: {list(kwargs.keys())}" + ) # Copy over points # Check if any of the points is None or has length 0 @@ -1254,7 +1316,9 @@ def is_empty(ix): jump_idx = jump_idx[1:] if self._jump_dk.size > 1 and jump_idx.size != self._jump_dk.size: - raise ValueError(f"{self.__class__.__name__} got inconsistent argument lengths (jump_dk does not match jumps in points)") + raise ValueError( + f"{self.__class__.__name__} got inconsistent argument lengths (jump_dk does not match jumps in points)" + ) # The jump-idx is equal to using np.split(self.points, jump_idx) # which then returns continuous sections @@ -1265,25 +1329,28 @@ def is_empty(ix): # If the array has fewer points we try and determine if self.points.shape[1] < 3: if self.points.shape[1] != np.sum(self.parent.nsc > 1): - raise ValueError('Could not determine the non-periodic direction') + raise ValueError("Could not determine the non-periodic direction") # fix the points where there are no periodicity for i in (0, 1, 2): if self.parent.nsc[i] == 1: - self.points = np.insert(self.points, i, 0., axis=1) + self.points = np.insert(self.points, i, 0.0, axis=1) # Ensure the shape is correct self.points.shape = (-1, 3) # Now figure out what to do with the divisions if isinstance(divisions, Integral): - if divisions < len(self.points): - raise ValueError(f"Can not evenly split {len(self.points)} points into {divisions} divisions, ensure division>=len(points)") + raise ValueError( + f"Can not evenly split {len(self.points)} points into {divisions} divisions, ensure division>=len(points)" + ) # Get length between different k-points with a total length # of division - dists = np.diff(linspace_bz(self.tocartesian(self.points), jumps=jump_idx, jump_dk=0.)) + dists = np.diff( + linspace_bz(self.tocartesian(self.points), jumps=jump_idx, jump_dk=0.0) + ) # Get floating point divisions divs_r = dists * divisions / dists.sum() @@ -1292,7 +1359,7 @@ def is_empty(ix): # ensure at least 1 point along each division # 1 division means only the starting point divs[divs == 0] = 1 - divs[jump_idx-1] = 1 + divs[jump_idx - 1] = 1 divs_sum = divs.sum() while divs_sum != divisions - 1: # only check indices where divs > 1 @@ -1307,23 +1374,29 @@ def is_empty(ix): divisions = divs[:] elif len(divisions) + 1 != len(self.points): - raise ValueError(f"inconsistent number of elements in 'points' and 'divisions' argument. One less 'divisions' elements.") + raise ValueError( + f"inconsistent number of elements in 'points' and 'divisions' argument. One less 'divisions' elements." + ) self.divisions = _a.arrayi(divisions).ravel() if names is None: - self.names = 'ABCDEFGHIJKLMNOPQRSTUVXYZ'[:len(self.points)] + self.names = "ABCDEFGHIJKLMNOPQRSTUVXYZ"[: len(self.points)] else: self.names = names if len(self.names) != len(self.points): - raise ValueError(f"inconsistent number of elements in 'points' and 'names' argument") + raise ValueError( + f"inconsistent number of elements in 'points' and 'names' argument" + ) # Calculate points dpoint = np.diff(self.points, axis=0) k = _a.emptyd([self.divisions.sum() + 1, 3]) i = 0 for ik, (divs, dk) in enumerate(zip(self.divisions, dpoint)): - k[i:i+divs, :] = self.points[ik] + dk * _a.aranged(divs).reshape(-1, 1) / divs + k[i : i + divs, :] = ( + self.points[ik] + dk * _a.aranged(divs).reshape(-1, 1) / divs + ) i += divs k[-1] = self.points[-1] # sanity check that should always be obeyed @@ -1333,7 +1406,7 @@ def is_empty(ix): self._w = _a.fulld(len(self.k), 1 / len(self.k)) def copy(self, parent=None): - """ Create a copy of this object, optionally changing the parent + """Create a copy of this object, optionally changing the parent Parameters ---------- @@ -1342,30 +1415,32 @@ def copy(self, parent=None): """ if parent is None: parent = self.parent - bz = self.__class__(parent, self.points, self.divisions, self.names, jump_dk=self._jump_dk) + bz = self.__class__( + parent, self.points, self.divisions, self.names, jump_dk=self._jump_dk + ) return bz def __getstate__(self): - """ Return dictionary with the current state """ + """Return dictionary with the current state""" state = super().__getstate__() - state['points'] = self.points.copy() - state['divisions'] = self.divisions.copy() - state['jump_idx'] = self._jump_idx.copy() - state['names'] = list(self.names) - state['jump_dk'] = self._jump_dk + state["points"] = self.points.copy() + state["divisions"] = self.divisions.copy() + state["jump_idx"] = self._jump_idx.copy() + state["names"] = list(self.names) + state["jump_dk"] = self._jump_dk return state def __setstate__(self, state): - """ Reset state of the object """ + """Reset state of the object""" super().__setstate__(state) - self.points = state['points'] - self.divisions = state['divisions'] - self.names = state['names'] - self._jump_dk = state['jump_dk'] - self._jump_idx = state['jump_idx'] + self.points = state["points"] + self.divisions = state["divisions"] + self.names = state["names"] + self._jump_dk = state["jump_dk"] + self._jump_idx = state["jump_idx"] def insert_jump(self, *arrays, value=np.nan): - """ Return a copy of `arrays` filled with `value` at indices of discontinuity jumps + """Return a copy of `arrays` filled with `value` at indices of discontinuity jumps Arrays with `value` in jumps is easier to plot since those lines will be naturally discontinued. For band structures without discontinuity jumps in the Brillouin zone the `arrays` will @@ -1403,7 +1478,8 @@ def insert_jump(self, *arrays, value=np.nan): return arrays nk = len(self) - full_jumps = np.cumsum(self.divisions)[self._jump_idx-1] + full_jumps = np.cumsum(self.divisions)[self._jump_idx - 1] + def _insert(array): array = np.asarray(array) # ensure dtype is equivalent as input array @@ -1421,7 +1497,7 @@ def _insert(array): return arrays def lineartick(self): - """ The tick-marks corresponding to the linear-k values + """The tick-marks corresponding to the linear-k values Returns ------- @@ -1435,7 +1511,7 @@ def lineartick(self): return self.lineark(True)[1:3] def tolinear(self, k, ret_index=False, tol=1e-4): - """ Convert a k-point into the equivalent linear k-point via the distance + """Convert a k-point into the equivalent linear k-point via the distance Finds the index of the k-point in `self.k` that is closests to `k`. The returned value is then the equivalent index in `lineark`. @@ -1454,7 +1530,7 @@ def tolinear(self, k, ret_index=False, tol=1e-4): The tolerance is in units 1/Ang. """ # Faster than to do sqrt all the time - tol = tol ** 2 + tol = tol**2 # first convert to the cartesian coordinates (for proper distances) ks = self.tocartesian(np.atleast_2d(k)) kk = self.tocartesian(self.k) @@ -1464,7 +1540,9 @@ def find(k): dist = ((kk - k) ** 2).sum(-1) idx = np.argmin(dist) if dist[idx] > tol: - warn(f"{self.__class__.__name__}.tolinear could not find a k-point within given tolerance ({self.toreduced(k)})") + warn( + f"{self.__class__.__name__}.tolinear could not find a k-point within given tolerance ({self.toreduced(k)})" + ) return idx idxs = [find(k) for k in ks] @@ -1473,7 +1551,7 @@ def find(k): return self.lineark()[idxs] def lineark(self, ticks=False): - """ A 1D array which corresponds to the delta-k values of the path + """A 1D array which corresponds to the delta-k values of the path This is mainly meant for plotting but may be useful for finding out distances in the reciprocal lattice. @@ -1514,7 +1592,9 @@ def lineark(self, ticks=False): cum_divs = np.cumsum(self.divisions) # Calculate points # First we also need to calculate the jumps - dK = linspace_bz(self, jumps=cum_divs[self._jump_idx-1], jump_dk=self._jump_dk) + dK = linspace_bz( + self, jumps=cum_divs[self._jump_idx - 1], jump_dk=self._jump_dk + ) # Get label tick, in case self.names is a single string 'ABCD' if ticks: diff --git a/src/sisl/physics/densitymatrix.py b/src/sisl/physics/densitymatrix.py index 99013f255d..62d2716c36 100644 --- a/src/sisl/physics/densitymatrix.py +++ b/src/sisl/physics/densitymatrix.py @@ -27,9 +27,8 @@ class _densitymatrix(SparseOrbitalBZSpin): - def spin_rotate(self, angles, rad=False): - r""" Rotates spin-boxes by fixed angles around the :math:`x`, :math:`y` and :math:`z` axis, respectively. + r"""Rotates spin-boxes by fixed angles around the :math:`x`, :math:`y` and :math:`z` axis, respectively. The angles are with respect to each spin-boxes initial angle. One should use `spin_align` to fix all angles along a specific direction. @@ -62,21 +61,17 @@ def spin_rotate(self, angles, rad=False): def cos_sin(a): return m.cos(a), m.sin(a) + calpha, salpha = cos_sin(angles[0]) cbeta, sbeta = cos_sin(angles[1]) cgamma, sgamma = cos_sin(angles[2]) del cos_sin # define rotation matrix - R = (np.array([[cgamma, -sgamma, 0], - [sgamma, cgamma, 0], - [0, 0, 1]]) - .dot([[cbeta, 0, sbeta], - [0, 1, 0], - [-sbeta, 0, cbeta]]) - .dot([[1, 0, 0], - [0, calpha, -salpha], - [0, salpha, calpha]]) + R = ( + np.array([[cgamma, -sgamma, 0], [sgamma, cgamma, 0], [0, 0, 1]]) + .dot([[cbeta, 0, sbeta], [0, 1, 0], [-sbeta, 0, cbeta]]) + .dot([[1, 0, 0], [0, calpha, -salpha], [0, salpha, calpha]]) ) if self.spin.is_noncolinear: @@ -85,7 +80,7 @@ def cos_sin(a): D = self._csr._D Q = (D[:, 0] + D[:, 1]) * 0.5 A[:, 0] = 2 * D[:, 2] - A[:, 1] = - 2 * D[:, 3] + A[:, 1] = -2 * D[:, 3] A[:, 2] = D[:, 0] - D[:, 1] A = R.dot(A.T).T * 0.5 @@ -112,7 +107,7 @@ def cos_sin(a): A[:, :, 2] = (D[:, 0] - D[:, 1]).reshape(-1, 1) A[:, 0, 0] = 2 * D[:, 2] A[:, 1, 0] = 2 * D[:, 6] - A[:, 0, 1] = - 2 * D[:, 3] + A[:, 0, 1] = -2 * D[:, 3] A[:, 1, 1] = 2 * D[:, 7] A = R.dot(A.reshape(-1, 3).T).T.reshape(-1, 2, 3) * 0.5 @@ -122,29 +117,38 @@ def cos_sin(a): D[:, 0] = Q + A[:, :, 2].sum(1) * 0.5 D[:, 1] = Q - A[:, :, 2].sum(1) * 0.5 D[:, 2] = A[:, 0, 0] - D[:, 3] = - A[:, 0, 1] + D[:, 3] = -A[:, 0, 1] # 4 and 5 are diagonal imaginary part (un-changed) # Since we copy, we don't need to do anything - #D[:, 4] = - #D[:, 5] = + # D[:, 4] = + # D[:, 5] = D[:, 6] = A[:, 1, 0] D[:, 7] = A[:, 1, 1] elif self.spin.is_polarized: + def close(a, v): return abs(abs(a) - v) < np.pi / 1080 # figure out if this is only rotating 180 for x or y - if close(angles[0], np.pi) and close(angles[1], 0) or \ - close(angles[0], 0) and close(angles[1], np.pi): + if ( + close(angles[0], np.pi) + and close(angles[1], 0) + or close(angles[0], 0) + and close(angles[1], np.pi) + ): # flip spin out = self.copy() out._csr._D[:, [0, 1]] = out._csr._D[:, [1, 0]] else: spin = Spin("nc", dtype=self.dtype) - out = self.__class__(self.geometry, dtype=self.dtype, spin=spin, - orthogonal=self.orthogonal) + out = self.__class__( + self.geometry, + dtype=self.dtype, + spin=spin, + orthogonal=self.orthogonal, + ) out._csr.ptr[:] = self._csr.ptr[:] out._csr.ncol[:] = self._csr.ncol[:] out._csr.col = self._csr.col.copy() @@ -159,12 +163,14 @@ def close(a, v): out = out.spin_rotate(angles, rad=True) else: - raise ValueError(f"{self.__class__.__name__}.spin_rotate requires a matrix with some spin configuration, not an unpolarized matrix.") + raise ValueError( + f"{self.__class__.__name__}.spin_rotate requires a matrix with some spin configuration, not an unpolarized matrix." + ) return out def spin_align(self, vec): - r""" Aligns *all* spin along the vector `vec` + r"""Aligns *all* spin along the vector `vec` In case the matrix is polarized and `vec` is not aligned at the z-axis, the returned matrix will be a non-collinear spin configuration. @@ -185,7 +191,7 @@ def spin_align(self, vec): """ vec = _a.asarrayd(vec) # normalize vector - vec = vec / (vec ** 2).sum() ** 0.5 + vec = vec / (vec**2).sum() ** 0.5 if self.spin.is_noncolinear: A = np.empty([len(self._csr._D), 3], dtype=self.dtype) @@ -193,13 +199,14 @@ def spin_align(self, vec): D = self._csr._D Q = (D[:, 0] + D[:, 1]) * 0.5 A[:, 0] = 2 * D[:, 2] - A[:, 1] = - 2 * D[:, 3] + A[:, 1] = -2 * D[:, 3] A[:, 2] = D[:, 0] - D[:, 1] # align with vector # add factor 1/2 here (instead when unwrapping) - A[:, :] = 0.5 * vec.reshape(1, 3) * (np.sum(A ** 2, axis=1) - .reshape(-1, 1)) ** 0.5 + A[:, :] = ( + 0.5 * vec.reshape(1, 3) * (np.sum(A**2, axis=1).reshape(-1, 1)) ** 0.5 + ) out = self.copy() D = out._csr._D @@ -222,33 +229,40 @@ def spin_align(self, vec): Q = (D[:, 0] + D[:, 1]) * 0.5 A[:, :, 2] = (D[:, 0] - D[:, 1]).reshape(-1, 1) A[:, 0, 0] = 2 * D[:, 2] - A[:, 0, 1] = - 2 * D[:, 3] + A[:, 0, 1] = -2 * D[:, 3] A[:, 1, 0] = 2 * D[:, 6] A[:, 1, 1] = 2 * D[:, 7] # align with vector # add factor 1/2 here (instead when unwrapping) - A[:, :, :] = 0.5 * vec.reshape(1, 1, 3) * (np.sum(A ** 2, axis=2) - .reshape(-1, 2, 1)) ** 0.5 + A[:, :, :] = ( + 0.5 + * vec.reshape(1, 1, 3) + * (np.sum(A**2, axis=2).reshape(-1, 2, 1)) ** 0.5 + ) out = self.copy() D = out._csr._D D[:, 0] = Q + A[:, :, 2].sum(1) * 0.5 D[:, 1] = Q - A[:, :, 2].sum(1) * 0.5 D[:, 2] = A[:, 0, 0] - D[:, 3] = - A[:, 0, 1] + D[:, 3] = -A[:, 0, 1] # 4 and 5 are diagonal imaginary part (un-changed) # Since we copy, we don't need to do anything - #D[:, 4] = - #D[:, 5] = + # D[:, 4] = + # D[:, 5] = D[:, 6] = A[:, 1, 0] D[:, 7] = A[:, 1, 1] elif self.spin.is_polarized: if abs(vec.sum() - vec[2]) > 1e-6: spin = Spin("nc", dtype=self.dtype) - out = self.__class__(self.geometry, dtype=self.dtype, spin=spin, - orthogonal=self.orthogonal) + out = self.__class__( + self.geometry, + dtype=self.dtype, + spin=spin, + orthogonal=self.orthogonal, + ) out._csr.ptr[:] = self._csr.ptr[:] out._csr.ncol[:] = self._csr.ncol[:] out._csr.col = self._csr.col.copy() @@ -270,11 +284,13 @@ def spin_align(self, vec): out = self.copy() else: - raise ValueError(f"{self.__class__.__name__}.spin_align requires a matrix with some spin configuration, not an unpolarized matrix.") + raise ValueError( + f"{self.__class__.__name__}.spin_align requires a matrix with some spin configuration, not an unpolarized matrix." + ) return out - def mulliken(self, projection='orbital'): + def mulliken(self, projection="orbital"): r""" Calculate Mulliken charges from the density matrix In the following :math:`\nu` and :math:`\mu` are orbital indices. @@ -319,12 +335,13 @@ def mulliken(self, projection='orbital'): if `projection` does not contain matrix, otherwise ``[spin, no]``, for polarized spin is [T, Sz] and for non-colinear spin is [T, Sx, Sy, Sz] """ + def _convert(M): - """ Converts a non-colinear DM from [11, 22, Re(12), Im(12)] -> [T, Sx, Sy, Sz] """ + """Converts a non-colinear DM from [11, 22, Re(12), Im(12)] -> [T, Sx, Sy, Sz]""" if M.shape[0] == 8: # We need to calculate the corresponding values M[2] = 0.5 * (M[2] + M[6]) - M[3] = 0.5 * (M[3] - M[7]) # sign change again below + M[3] = 0.5 * (M[3] - M[7]) # sign change again below M = M[:4] elif M.shape[0] == 2: # necessary to not overwrite data @@ -339,7 +356,7 @@ def _convert(M): m[0] = M[0] + M[1] m[3] = M[0] - M[1] m[1] = 2 * M[2] - m[2] = - 2 * M[3] + m[2] = -2 * M[3] else: return M return m @@ -347,7 +364,9 @@ def _convert(M): if "orbital" == projection: # Orbital Mulliken population if self.orthogonal: - D = np.array([self._csr.tocsr(i).diagonal() for i in range(self.shape[2])]) + D = np.array( + [self._csr.tocsr(i).diagonal() for i in range(self.shape[2])] + ) else: D = self._csr.copy(range(self.shape[2] - 1)) D._D *= self._csr._D[:, -1].reshape(-1, 1) @@ -358,7 +377,9 @@ def _convert(M): elif "atom" == projection: # Atomic Mulliken population if self.orthogonal: - D = np.array([self._csr.tocsr(i).diagonal() for i in range(self.shape[2])]) + D = np.array( + [self._csr.tocsr(i).diagonal() for i in range(self.shape[2])] + ) else: D = self._csr.copy(range(self.shape[2] - 1)) D._D *= self._csr._D[:, -1].reshape(-1, 1) @@ -372,10 +393,12 @@ def _convert(M): return _convert(M) - raise NotImplementedError(f"{self.__class__.__name__}.mulliken only allows projection [orbital, atom]") + raise NotImplementedError( + f"{self.__class__.__name__}.mulliken only allows projection [orbital, atom]" + ) def density(self, grid, spinor=None, tol=1e-7, eta=None): - r""" Expand the density matrix to the charge density on a grid + r"""Expand the density matrix to the charge density on a grid This routine calculates the real-space density components on a specified grid. @@ -420,8 +443,10 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): # Otherwise we raise an ImportError unique([[0, 1], [2, 3]], axis=0) except Exception: - raise NotImplementedError(f"{self.__class__.__name__}.density requires numpy >= 1.13, either update " - "numpy or do not use this function!") + raise NotImplementedError( + f"{self.__class__.__name__}.density requires numpy >= 1.13, either update " + "numpy or do not use this function!" + ) geometry = self.geometry # Check that the atomic coordinates, really are all within the intrinsic supercell. @@ -441,7 +466,7 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): # In the following we don't care about division # So 1) save error state, 2) turn off divide by 0, 3) calculate, 4) turn on old error state - old_err = np.seterr(divide='ignore', invalid='ignore') + old_err = np.seterr(divide="ignore", invalid="ignore") # Placeholder for the resulting coefficients DM = None @@ -452,7 +477,9 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): else: spinor = _a.arrayz(spinor) if spinor.size != 4 or spinor.ndim != 2: - raise ValueError(f"{self.__class__.__name__}.density with NC/SO spin, requires a 2x2 matrix.") + raise ValueError( + f"{self.__class__.__name__}.density with NC/SO spin, requires a 2x2 matrix." + ) DM = _a.emptyz([self.nnz, 2, 2]) idx = _a.array_arange(csr.ptr[:-1], n=csr.ncol) @@ -479,14 +506,16 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): elif isinstance(spinor, Integral): # extract the provided spin-polarization s = _a.zerosd(2) - s[spinor] = 1. + s[spinor] = 1.0 spinor = s else: spinor = _a.arrayd(spinor) if spinor.size != 2 or spinor.ndim != 1: - raise ValueError(f"{self.__class__.__name__}.density with polarized spin, requires spinor " - "argument as an integer, or a vector of length 2") + raise ValueError( + f"{self.__class__.__name__}.density with polarized spin, requires spinor " + "argument as an integer, or a vector of length 2" + ) idx = _a.array_arange(csr.ptr[:-1], n=csr.ncol) DM = csr._D[idx, 0] * spinor[0] + csr._D[idx, 1] * spinor[1] @@ -496,8 +525,11 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): DM = csr._D[idx, 0] # Create the DM csr matrix. - csrDM = csr_matrix((DM, csr.col[idx], _ncol_to_indptr(csr.ncol)), - shape=(self.shape[:2]), dtype=DM.dtype) + csrDM = csr_matrix( + (DM, csr.col[idx], _ncol_to_indptr(csr.ncol)), + shape=(self.shape[:2]), + dtype=DM.dtype, + ) # Clean-up del idx, DM @@ -524,11 +556,11 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): csr_sum[i_s] = csr # Recreate the column-stacked csr matrix - csrDM = ss_hstack(csr_sum, format='csr') + csrDM = ss_hstack(csr_sum, format="csr") del csr, csr_sum # Remove all zero elements (note we use the tolerance here!) - csrDM.data = np.where(np.fabs(csrDM.data) > tol, csrDM.data, 0.) + csrDM.data = np.where(np.fabs(csrDM.data) > tol, csrDM.data, 0.0) # Eliminate zeros and sort indices etc. csrDM.eliminate_zeros() @@ -538,8 +570,10 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): # 1. Ensure the grid has a geometry associated with it lattice = grid.lattice.copy() # Find the periodic directions - pbc = [bc == BC.PERIODIC or geometry.nsc[i] > 1 - for i, bc in enumerate(grid.lattice.boundary_condition[:, 0])] + pbc = [ + bc == BC.PERIODIC or geometry.nsc[i] > 1 + for i, bc in enumerate(grid.lattice.boundary_condition[:, 0]) + ] if grid.geometry is None: # Create the actual geometry that encompass the grid ia, xyz, _ = geometry.within_inf(lattice, periodic=pbc) @@ -570,7 +604,7 @@ def density(self, grid, spinor=None, tol=1e-7, eta=None): a2o = geometry.a2o def xyz2spherical(xyz, offset): - """ Calculate the spherical coordinates from indices """ + """Calculate the spherical coordinates from indices""" rx = xyz[:, 0] - offset[0] ry = xyz[:, 1] - offset[1] rz = xyz[:, 2] - offset[2] @@ -580,7 +614,7 @@ def xyz2spherical(xyz, offset): return rx, ry, rz def xyz2sphericalR(xyz, offset, R): - """ Calculate the spherical coordinates from indices """ + """Calculate the spherical coordinates from indices""" rx = xyz[:, 0] - offset[0] idx = indices_fabs_le(rx, R) ry = xyz[idx, 1] - offset[1] @@ -597,7 +631,7 @@ def xyz2sphericalR(xyz, offset, R): rx = rx[idx] # Calculate radius ** 2 - ix = indices_le(rx ** 2 + ry ** 2 + rz ** 2, R ** 2) + ix = indices_le(rx**2 + ry**2 + rz**2, R**2) idx = idx[ix] if len(idx) == 0: return [], [], [], [] @@ -643,7 +677,7 @@ def xyz2sphericalR(xyz, offset, R): # Extract maximum R R = ia_atom.maxR() - if R <= 0.: + if R <= 0.0: warn(f"Atom '{ia_atom}' does not have a wave-function, skipping atom.") eta.update() continue @@ -679,13 +713,13 @@ def xyz2sphericalR(xyz, offset, R): # This will have a size equal to number of elements times number of # orbitals on this atom # In this way we do not have to calculate the psi_j multiple times - DM_io = csrDM[IO:IO+ia_atom.no, :].tolil() + DM_io = csrDM[IO : IO + ia_atom.no, :].tolil() DM_pj = _a.zerosd([ia_atom.no, grid_xyz.shape[0]]) # Now we perform the loop on the connections for this atom # Remark that we have removed the diagonal atom (it-self) # As that will be calculated in the end - for ja in a_col[a_ptr[ia]:a_ptr[ia+1]]: + for ja in a_col[a_ptr[ia] : a_ptr[ia + 1]]: # Retrieve atom (which contains the orbitals) ja_atom = atoms[ja % na] JO = a2o(ja) @@ -694,7 +728,9 @@ def xyz2sphericalR(xyz, offset, R): ja_xyz = axyz(ja) + cell_offset # Reduce the ia'th grid points to those that connects to the ja'th atom - ja_idx, ja_r, ja_theta, ja_cos_phi = xyz2sphericalR(grid_xyz, ja_xyz, jR) + ja_idx, ja_r, ja_theta, ja_cos_phi = xyz2sphericalR( + grid_xyz, ja_xyz, jR + ) if len(ja_idx) == 0: # Quick step @@ -728,7 +764,7 @@ def xyz2sphericalR(xyz, offset, R): # Now add this orbital to all components for io in IO_range: - DM_pj[io, ja_idx1] += DM_io[io, JO+jo] * psi + DM_pj[io, ja_idx1] += DM_io[io, JO + jo] * psi # Temporary clean up del ja_idx, ja_r, ja_theta, ja_cos_phi @@ -744,8 +780,8 @@ def xyz2sphericalR(xyz, offset, R): # Only loop halve the range. # This is because: triu + tril(-1).transpose() # removes the lower half of the on-site matrix. - for jo in range(io+1, ia_atom.no): - DM = DM_io[io, off+IO+jo] + for jo in range(io + 1, ia_atom.no): + DM = DM_io[io, off + IO + jo] oj = ia_atom.orbitals[jo] ojR = oj.R @@ -768,14 +804,18 @@ def xyz2sphericalR(xyz, offset, R): ja_cos_phi1 = ia_cos_phi[ja_idx1] # Calculate the psi_j component - DM_pj[io, ja_idx1] += DM * oj.psi_spher(ja_r1, ja_theta1, ja_cos_phi1, cos_phi=True) + DM_pj[io, ja_idx1] += DM * oj.psi_spher( + ja_r1, ja_theta1, ja_cos_phi1, cos_phi=True + ) # Calculate the psi_i component # Note that this one *also* zeroes points outside the shell # I.e. this step is important because it "nullifies" all but points where # orbital io is defined. - psi = ia_atom.orbitals[io].psi_spher(ia_r, ia_theta, ia_cos_phi, cos_phi=True) - DM_pj[io, :] += DM_io[io, off+IO+io] * psi + psi = ia_atom.orbitals[io].psi_spher( + ia_r, ia_theta, ia_cos_phi, cos_phi=True + ) + DM_pj[io, :] += DM_io[io, off + IO + io] * psi DM_pj[io, :] *= psi # Temporary clean up @@ -797,7 +837,7 @@ def xyz2sphericalR(xyz, offset, R): @set_module("sisl.physics") class DensityMatrix(_densitymatrix): - """ Sparse density matrix object + """Sparse density matrix object Assigning or changing elements is as easy as with standard `numpy` assignments: @@ -866,7 +906,7 @@ class DensityMatrix(_densitymatrix): """ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): - """ Initialize density matrix """ + """Initialize density matrix""" super().__init__(geometry, dim, dtype, nnzpr, **kwargs) self._reset() @@ -878,12 +918,12 @@ def _reset(self): @property def D(self): - r""" Access the density matrix elements """ + r"""Access the density matrix elements""" self._def_dim = self.UP return self - def orbital_momentum(self, projection='orbital', method='onsite'): - r""" Calculate orbital angular momentum on either atoms or orbitals + def orbital_momentum(self, projection="orbital", method="onsite"): + r"""Calculate orbital angular momentum on either atoms or orbitals Currently this implementation equals the Siesta implementation in that the on-site approximation is enforced thus limiting the calculated quantities @@ -911,7 +951,9 @@ def orbital_momentum(self, projection='orbital', method='onsite'): """ # Check that the spin configuration is correct if not self.spin.is_spinorbit: - raise ValueError(f"{self.__class__.__name__}.orbital_momentum requires a spin-orbit matrix") + raise ValueError( + f"{self.__class__.__name__}.orbital_momentum requires a spin-orbit matrix" + ) # First we calculate orb_lmZ = _a.emptyi([self.no, 3]) @@ -951,11 +993,13 @@ def orbital_momentum(self, projection='orbital', method='onsite'): # 3. same quantum number l # 4. different quantum number m # 5. same zeta - onsite_idx = ((aidx == ajdx) & \ - (orb_lmZ[idx, 0] > 0) & \ - (orb_lmZ[idx, 0] == orb_lmZ[jdx, 0]) & \ - (orb_lmZ[idx, 1] != orb_lmZ[jdx, 1]) & \ - (orb_lmZ[idx, 2] == orb_lmZ[jdx, 2])).nonzero()[0] + onsite_idx = ( + (aidx == ajdx) + & (orb_lmZ[idx, 0] > 0) + & (orb_lmZ[idx, 0] == orb_lmZ[jdx, 0]) + & (orb_lmZ[idx, 1] != orb_lmZ[jdx, 1]) + & (orb_lmZ[idx, 2] == orb_lmZ[jdx, 2]) + ).nonzero()[0] # clean variables we don't need del aidx, ajdx @@ -1097,10 +1141,12 @@ def Lc(idx, idx_l, DM, sub): l = np.zeros([3, geom.na], dtype=L.dtype) add.at(l.T, geom.o2a(np.arange(geom.no)), L.T) return l - raise ValueError(f"{self.__class__.__name__}.orbital_momentum must define projection to be 'orbital' or 'atom'.") + raise ValueError( + f"{self.__class__.__name__}.orbital_momentum must define projection to be 'orbital' or 'atom'." + ) - def Dk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the density matrix for a given k-point + def Dk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the density matrix for a given k-point Creation and return of the density matrix for a given k-point (default to Gamma). @@ -1155,8 +1201,8 @@ def Dk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): """ pass - def dDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the density matrix derivative for a given k-point + def dDk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the density matrix derivative for a given k-point Creation and return of the density matrix derivative for a given k-point (default to Gamma). @@ -1210,8 +1256,8 @@ def dDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs) """ pass - def ddDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the density matrix double derivative for a given k-point + def ddDk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the density matrix double derivative for a given k-point Creation and return of the density matrix double derivative for a given k-point (default to Gamma). @@ -1267,7 +1313,7 @@ def ddDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs @staticmethod def read(sile, *args, **kwargs): - """ Reads density matrix from `Sile` using `read_density_matrix`. + """Reads density matrix from `Sile` using `read_density_matrix`. Parameters ---------- @@ -1280,19 +1326,21 @@ def read(sile, *args, **kwargs): # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_density_matrix(*args, **kwargs) else: - with get_sile(sile, mode='r') as fh: + with get_sile(sile, mode="r") as fh: return fh.read_density_matrix(*args, **kwargs) def write(self, sile, *args, **kwargs) -> None: - """ Writes a density matrix to the `Sile` as implemented in the :code:`Sile.write_density_matrix` method """ + """Writes a density matrix to the `Sile` as implemented in the :code:`Sile.write_density_matrix` method""" # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): sile.write_density_matrix(self, *args, **kwargs) else: - with get_sile(sile, mode='w') as fh: + with get_sile(sile, mode="w") as fh: fh.write_density_matrix(self, *args, **kwargs) diff --git a/src/sisl/physics/distribution.py b/src/sisl/physics/distribution.py index 9fb0ab396d..06ae01b610 100644 --- a/src/sisl/physics/distribution.py +++ b/src/sisl/physics/distribution.py @@ -28,14 +28,14 @@ _pi = np.pi _sqrt_2pi = (2 * _pi) ** 0.5 -__all__ = ['get_distribution', 'gaussian', 'lorentzian'] -__all__ += ['fermi_dirac', 'bose_einstein', 'cold'] -__all__ += ['step_function', 'heaviside'] +__all__ = ["get_distribution", "gaussian", "lorentzian"] +__all__ += ["fermi_dirac", "bose_einstein", "cold"] +__all__ += ["step_function", "heaviside"] @set_module("sisl.physics") -def get_distribution(method, smearing=0.1, x0=0.): - r""" Create a distribution function, Gaussian, Lorentzian etc. +def get_distribution(method, smearing=0.1, x0=0.0): + r"""Create a distribution function, Gaussian, Lorentzian etc. See the details regarding the distributions in their respective documentation. @@ -53,27 +53,29 @@ def get_distribution(method, smearing=0.1, x0=0.): callable a function which accepts one argument """ - m = method.lower().replace('-', '_') - if m in ('gauss', 'gaussian'): + m = method.lower().replace("-", "_") + if m in ("gauss", "gaussian"): return partial(gaussian, sigma=smearing, x0=x0) - elif m in ('lorentz', 'lorentzian'): + elif m in ("lorentz", "lorentzian"): return partial(lorentzian, gamma=smearing, x0=x0) - elif m in ('fd', 'fermi', 'fermi_dirac'): + elif m in ("fd", "fermi", "fermi_dirac"): return partial(fermi_dirac, kT=smearing, mu=x0) - elif m in ('be', 'bose_einstein'): + elif m in ("be", "bose_einstein"): return partial(bose_einstein, kT=smearing, mu=x0) - elif m in ('cold'): + elif m in ("cold"): return partial(cold, kT=smearing, mu=x0) - elif m in ('step', 'step_function'): + elif m in ("step", "step_function"): return partial(step_function, x0=x0) - elif m in ('heavi', 'heavy', 'heaviside'): + elif m in ("heavi", "heavy", "heaviside"): return partial(heaviside, x0=x0) - raise ValueError(f"get_distribution does not implement the {method} distribution function, have you mispelled?") + raise ValueError( + f"get_distribution does not implement the {method} distribution function, have you mispelled?" + ) @set_module("sisl.physics") -def gaussian(x, sigma=0.1, x0=0.): - r""" Gaussian distribution function +def gaussian(x, sigma=0.1, x0=0.0): + r"""Gaussian distribution function .. math:: G(x,\sigma,x_0) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\Big[\frac{- (x - x_0)^2}{2\sigma^2}\Big] @@ -92,13 +94,13 @@ def gaussian(x, sigma=0.1, x0=0.): numpy.ndarray the Gaussian distribution, same length as `x` """ - dx = (x - x0) / (sigma * 2 ** 0.5) - return exp(- dx * dx) / (_sqrt_2pi * sigma) + dx = (x - x0) / (sigma * 2**0.5) + return exp(-dx * dx) / (_sqrt_2pi * sigma) @set_module("sisl.physics") -def lorentzian(x, gamma=0.1, x0=0.): - r""" Lorentzian distribution function +def lorentzian(x, gamma=0.1, x0=0.0): + r"""Lorentzian distribution function .. math:: L(x,\gamma,x_0) = \frac{1}{\pi}\frac{\gamma}{(x-x_0)^2 + \gamma^2} @@ -121,8 +123,8 @@ def lorentzian(x, gamma=0.1, x0=0.): @set_module("sisl.physics") -def fermi_dirac(E, kT=0.1, mu=0.): - r""" Fermi-Dirac distribution function +def fermi_dirac(E, kT=0.1, mu=0.0): + r"""Fermi-Dirac distribution function .. math:: n_F(E,k_BT,\mu) = \frac{1}{\exp\Big[\frac{E - \mu}{k_BT}\Big] + 1} @@ -141,12 +143,12 @@ def fermi_dirac(E, kT=0.1, mu=0.): numpy.ndarray the Fermi-Dirac distribution, same length as `E` """ - return 1. / (expm1((E - mu) / kT) + 2.) + return 1.0 / (expm1((E - mu) / kT) + 2.0) @set_module("sisl.physics") -def bose_einstein(E, kT=0.1, mu=0.): - r""" Bose-Einstein distribution function +def bose_einstein(E, kT=0.1, mu=0.0): + r"""Bose-Einstein distribution function .. math:: n_B(E,k_BT,\mu) = \frac{1}{\exp\Big[\frac{E - \mu}{k_BT}\Big] - 1} @@ -165,11 +167,11 @@ def bose_einstein(E, kT=0.1, mu=0.): numpy.ndarray the Bose-Einstein distribution, same length as `E` """ - return 1. / expm1((E - mu) / kT) + return 1.0 / expm1((E - mu) / kT) @set_module("sisl.physics") -def cold(E, kT=0.1, mu=0.): +def cold(E, kT=0.1, mu=0.0): r""" Cold smearing function For more details see :cite:`Marzari1999`. @@ -193,12 +195,12 @@ def cold(E, kT=0.1, mu=0.): numpy.ndarray the Cold smearing distribution function, same length as `E` """ - x = - (E - mu) / kT - 1 / 2 ** 0.5 - return 0.5 + 0.5 * erf(x) + exp(- x * x) / _sqrt_2pi + x = -(E - mu) / kT - 1 / 2**0.5 + return 0.5 + 0.5 * erf(x) + exp(-x * x) / _sqrt_2pi @set_module("sisl.physics") -def heaviside(x, x0=0.): +def heaviside(x, x0=0.0): r""" Heaviside step function .. math:: @@ -227,13 +229,13 @@ def heaviside(x, x0=0.): the Heaviside step function distribution, same length as `x` """ H = np.zeros_like(x) - H[x > x0] = 1. + H[x > x0] = 1.0 H[x == x0] = 0.5 return H @set_module("sisl.physics") -def step_function(x, x0=0.): +def step_function(x, x0=0.0): r""" Step function, also known as :math:`1 - H(x)` This function equals one minus the Heaviside step function @@ -264,6 +266,6 @@ def step_function(x, x0=0.): the step function distribution, same length as `x` """ s = np.ones_like(x) - s[x > x0] = 0. + s[x > x0] = 0.0 s[x == x0] = 0.5 return s diff --git a/src/sisl/physics/dynamicalmatrix.py b/src/sisl/physics/dynamicalmatrix.py index 0bdf61ba55..6358996f6b 100644 --- a/src/sisl/physics/dynamicalmatrix.py +++ b/src/sisl/physics/dynamicalmatrix.py @@ -9,7 +9,7 @@ from .phonon import EigenmodePhonon, EigenvaluePhonon from .sparse import SparseOrbitalBZ -__all__ = ['DynamicalMatrix'] +__all__ = ["DynamicalMatrix"] def _correct_hw(hw): @@ -23,7 +23,7 @@ def _correct_hw(hw): @set_module("sisl.physics") class DynamicalMatrix(SparseOrbitalBZ): - """ Dynamical matrix of a geometry """ + """Dynamical matrix of a geometry""" def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): super().__init__(geometry, dim, dtype, nnzpr, **kwargs) @@ -37,12 +37,12 @@ def _reset(self): @property def D(self): - r""" Access the dynamical matrix elements """ + r"""Access the dynamical matrix elements""" self._def_dim = 0 return self - def Dk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the dynamical matrix for a given k-point + def Dk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the dynamical matrix for a given k-point Creation and return of the dynamical matrix for a given k-point (default to Gamma). @@ -93,8 +93,8 @@ def Dk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): """ pass - def dDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the dynamical matrix derivative for a given k-point + def dDk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the dynamical matrix derivative for a given k-point Creation and return of the dynamical matrix derivative for a given k-point (default to Gamma). @@ -144,8 +144,8 @@ def dDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs) """ pass - def ddDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the dynamical matrix double derivative for a given k-point + def ddDk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the dynamical matrix double derivative for a given k-point Creation and return of the dynamical matrix double derivative for a given k-point (default to Gamma). @@ -196,7 +196,7 @@ def ddDk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs pass def apply_newton(self): - """ Sometimes the dynamical matrix does not obey Newtons 3rd law. + """Sometimes the dynamical matrix does not obey Newtons 3rd law. We correct the dynamical matrix by imposing zero force. @@ -208,7 +208,7 @@ def apply_newton(self): d_uc = lil_matrix((no, no), dtype=dyn_sc.dtype) for i, _ in self.lattice: - d_uc[:, :] += dyn_sc[:, i*no: (i+1)*no] + d_uc[:, :] += dyn_sc[:, i * no : (i + 1) * no] # A CSC matrix is faster to slice for columns d_uc = d_uc.tocsc() @@ -219,7 +219,6 @@ def apply_newton(self): MM = np.empty([len(om)], np.float64) for ja in self.geometry: - # Create conversion to force-constant in units of the on-site mass scaled # dynamical matrix. MM[:] = om[:] / om[ja] @@ -249,8 +248,8 @@ def apply_newton(self): del d_uc - def eigenvalue(self, k=(0, 0, 0), gauge='R', **kwargs): - """ Calculate the eigenvalues at `k` and return an `EigenvaluePhonon` object containing all eigenvalues for a given `k` + def eigenvalue(self, k=(0, 0, 0), gauge="R", **kwargs): + """Calculate the eigenvalues at `k` and return an `EigenvaluePhonon` object containing all eigenvalues for a given `k` Parameters ---------- @@ -273,15 +272,15 @@ def eigenvalue(self, k=(0, 0, 0), gauge='R', **kwargs): ------- EigenvaluePhonon """ - if kwargs.pop('sparse', False): + if kwargs.pop("sparse", False): hw = self.eigsh(k, gauge=gauge, eigvals_only=True, **kwargs) else: hw = self.eigh(k, gauge, eigvals_only=True, **kwargs) - info = {'k': k, 'gauge': gauge} + info = {"k": k, "gauge": gauge} return EigenvaluePhonon(_correct_hw(hw), self, **info) - def eigenmode(self, k=(0, 0, 0), gauge='R', **kwargs): - r""" Calculate the eigenmodes at `k` and return an `EigenmodePhonon` object containing all eigenmodes + def eigenmode(self, k=(0, 0, 0), gauge="R", **kwargs): + r"""Calculate the eigenmodes at `k` and return an `EigenmodePhonon` object containing all eigenmodes Note that the phonon modes are _not_ mass-scaled. @@ -306,17 +305,17 @@ def eigenmode(self, k=(0, 0, 0), gauge='R', **kwargs): ------- EigenmodePhonon """ - if kwargs.pop('sparse', False): + if kwargs.pop("sparse", False): hw, v = self.eigsh(k, gauge=gauge, eigvals_only=False, **kwargs) else: hw, v = self.eigh(k, gauge, eigvals_only=False, **kwargs) - info = {'k': k, 'gauge': gauge} + info = {"k": k, "gauge": gauge} # Since eigh returns the eigenvectors [:, i] we have to transpose return EigenmodePhonon(v.T, _correct_hw(hw), self, **info) @staticmethod def read(sile, *args, **kwargs): - """ Reads dynamical matrix from `Sile` using `read_dynamical_matrix`. + """Reads dynamical matrix from `Sile` using `read_dynamical_matrix`. Parameters ---------- @@ -328,19 +327,21 @@ def read(sile, *args, **kwargs): # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_dynamical_matrix(*args, **kwargs) else: - with get_sile(sile, mode='r') as fh: + with get_sile(sile, mode="r") as fh: return fh.read_dynamical_matrix(*args, **kwargs) def write(self, sile, *args, **kwargs) -> None: - """ Writes a dynamical matrix to the `Sile` as implemented in the :code:`Sile.write_dynamical_matrix` method """ + """Writes a dynamical matrix to the `Sile` as implemented in the :code:`Sile.write_dynamical_matrix` method""" # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): sile.write_dynamical_matrix(self, *args, **kwargs) else: - with get_sile(sile, mode='w') as fh: + with get_sile(sile, mode="w") as fh: fh.write_dynamical_matrix(self, *args, **kwargs) diff --git a/src/sisl/physics/electron.py b/src/sisl/physics/electron.py index 0e6e7fa92d..34a11d0790 100644 --- a/src/sisl/physics/electron.py +++ b/src/sisl/physics/electron.py @@ -96,7 +96,7 @@ @set_module("sisl.physics.electron") def DOS(E, eig, distribution="gaussian"): - r""" Calculate the density of states (DOS) for a set of energies, `E`, with a distribution function + r"""Calculate the density of states (DOS) for a set of energies, `E`, with a distribution function The :math:`\mathrm{DOS}(E)` is calculated as: @@ -132,7 +132,7 @@ def DOS(E, eig, distribution="gaussian"): if isinstance(distribution, str): distribution = get_distribution(distribution) - return reduce(lambda DOS, eig: DOS + distribution(E - eig), eig, 0.) + return reduce(lambda DOS, eig: DOS + distribution(E - eig), eig, 0.0) @set_module("sisl.physics.electron") @@ -215,9 +215,11 @@ def PDOS(E, eig, state, S=None, distribution="gaussian", spin=None): # Figure out whether we are dealing with a non-colinear calculation if S is None: + class S: __slots__ = [] shape = (state.shape[1], state.shape[1]) + @staticmethod def dot(v): return v @@ -238,7 +240,9 @@ def dot(v): S = S[::2, ::2] # Initialize data - PDOS = empty([4, state.shape[1] // 2, len(E)], dtype=dtype_complex_to_real(state.dtype)) + PDOS = empty( + [4, state.shape[1] // 2, len(E)], dtype=dtype_complex_to_real(state.dtype) + ) # Do spin-box calculations: # PDOS[0] = total DOS (diagonal) @@ -249,13 +253,13 @@ def dot(v): d = distribution(E - eig[0]).reshape(1, -1) cs = conj(state[0]).reshape(-1, 2) v = S.dot(state[0].reshape(-1, 2)) - D1 = (cs * v).real # uu,dd PDOS - PDOS[0, :, :] = D1.sum(1).reshape(-1, 1) * d # total DOS - PDOS[3, :, :] = (D1[:, 0] - D1[:, 1]).reshape(-1, 1) * d # z-dos - D1 = (cs[:, 1] * v[:, 0]).reshape(-1, 1) # d,u - D2 = (cs[:, 0] * v[:, 1]).reshape(-1, 1) # u,d - PDOS[1, :, :] = (D1.real + D2.real) * d # x-dos - PDOS[2, :, :] = (D2.imag - D1.imag) * d # y-dos + D1 = (cs * v).real # uu,dd PDOS + PDOS[0, :, :] = D1.sum(1).reshape(-1, 1) * d # total DOS + PDOS[3, :, :] = (D1[:, 0] - D1[:, 1]).reshape(-1, 1) * d # z-dos + D1 = (cs[:, 1] * v[:, 0]).reshape(-1, 1) # d,u + D2 = (cs[:, 0] * v[:, 1]).reshape(-1, 1) # u,d + PDOS[1, :, :] = (D1.real + D2.real) * d # x-dos + PDOS[2, :, :] = (D2.imag - D1.imag) * d # y-dos for i in range(1, len(eig)): d = distribution(E - eig[i]).reshape(1, -1) cs = conj(state[i]).reshape(-1, 2) @@ -269,12 +273,14 @@ def dot(v): PDOS[2, :, :] += (D2.imag - D1.imag) * d else: - PDOS = (conj(state[0]) * S.dot(state[0])).real.reshape(-1, 1) \ - * distribution(E - eig[0]).reshape(1, -1) + PDOS = (conj(state[0]) * S.dot(state[0])).real.reshape(-1, 1) * distribution( + E - eig[0] + ).reshape(1, -1) for i in range(1, len(eig)): - PDOS += (conj(state[i]) * S.dot(state[i])).real.reshape(-1, 1) \ - * distribution(E - eig[i]).reshape(1, -1) + PDOS += (conj(state[i]) * S.dot(state[i])).real.reshape( + -1, 1 + ) * distribution(E - eig[i]).reshape(1, -1) PDOS.shape = (1, *PDOS.shape) return PDOS @@ -282,7 +288,7 @@ def dot(v): @set_module("sisl.physics.electron") def COP(E, eig, state, M, distribution="gaussian", tol=1e-10): - r""" Calculate the Crystal Orbital Population for a set of energies, `E`, with a distribution function + r"""Calculate the Crystal Orbital Population for a set of energies, `E`, with a distribution function The :math:`\mathrm{COP}(E)` is calculated as: @@ -337,7 +343,9 @@ def COP(E, eig, state, M, distribution="gaussian", tol=1e-10): if isinstance(distribution, str): distribution = get_distribution(distribution) - assert len(eig) == len(state), "COP: number of eigenvalues and states are not consistent" + assert len(eig) == len( + state + ), "COP: number of eigenvalues and states are not consistent" # get default dtype dtype = dtype_complex_to_real(state.dtype) @@ -347,7 +355,6 @@ def COP(E, eig, state, M, distribution="gaussian", tol=1e-10): n_s = M.shape[1] // M.shape[0] if isinstance(M, _FakeMatrix): - # A fake matrix equals the identity matrix. # Hence we can do all calculations only on the diagonal, # then finally we recreate the full matrix dimensions. @@ -358,7 +365,7 @@ def new_list(bools, tmp, we): if bl: yield tmp * w else: - yield 0. + yield 0.0 for e, s in zip(eig, state): # calculate contribution from this state @@ -370,27 +377,27 @@ def new_list(bools, tmp, we): # Now recreate the full size (in sparse form) idx = np.arange(no) + def tosize(diag, idx): return csr_matrix((diag, (idx, idx)), shape=M.shape) cop = oplist(tosize(d, idx) for d in cop) elif issparse(M): - # create the new list - cop0 = M.multiply(0.).real + cop0 = M.multiply(0.0).real cop = oplist(cop0.copy() for _ in range(len(E))) del cop0 # split M, then we will rejoin afterwards - Ms = [M[:, i*no:(i+1)*no] for i in range(n_s)] + Ms = [M[:, i * no : (i + 1) * no] for i in range(n_s)] def new_list(bools, tmp, we): for bl, w in zip(bools, we): if bl: yield tmp.multiply(w) else: - yield 0. + yield 0.0 for e, s in zip(eig, state): # calculate contribution from this state @@ -470,9 +477,11 @@ def spin_moment(state, S=None, project=False): return spin_moment(state.reshape(1, -1), S, project)[0] if S is None: + class S: __slots__ = [] shape = (state.shape[1] // 2, state.shape[1] // 2) + @staticmethod def dot(v): return v @@ -483,7 +492,10 @@ def dot(v): # see PDOS for details related to the spin-box calculations if project: - s = empty([3, state.shape[0], state.shape[1] // 2], dtype=dtype_complex_to_real(state.dtype)) + s = empty( + [3, state.shape[0], state.shape[1] // 2], + dtype=dtype_complex_to_real(state.dtype), + ) for i in range(len(state)): cs = conj(state[i]).reshape(-1, 2) @@ -552,19 +564,25 @@ def spin_squared(state_alpha, state_beta, S=None): """ if state_alpha.ndim == 1: if state_beta.ndim == 1: - Sa, Sb = spin_squared(state_alpha.reshape(1, -1), state_beta.reshape(1, -1), S) + Sa, Sb = spin_squared( + state_alpha.reshape(1, -1), state_beta.reshape(1, -1), S + ) return oplist((Sa[0], Sb[0])) return spin_squared(state_alpha.reshape(1, -1), state_beta, S) elif state_beta.ndim == 1: return spin_squared(state_alpha, state_beta.reshape(1, -1), S) if state_alpha.shape[1] != state_beta.shape[1]: - raise ValueError("spin_squared requires alpha and beta states to have same number of orbitals") + raise ValueError( + "spin_squared requires alpha and beta states to have same number of orbitals" + ) if S is None: + class S: __slots__ = [] shape = (state_alpha.shape[1], state_alpha.shape[1]) + @staticmethod def dot(v): return v @@ -604,8 +622,10 @@ def dot(v): _velocity_const = 1 / constant.hbar("eV ps") -def _velocity_matrix_non_ortho(state, dHk, energy, dSk, degenerate, degenerate_dir, dtype): - r""" For states in a non-orthogonal basis """ +def _velocity_matrix_non_ortho( + state, dHk, energy, dSk, degenerate, degenerate_dir, dtype +): + r"""For states in a non-orthogonal basis""" # All matrix elements along the 3 directions n = state.shape[0] @@ -614,8 +634,8 @@ def _velocity_matrix_non_ortho(state, dHk, energy, dSk, degenerate, degenerate_d # Decouple the degenerate states if not degenerate is None: degenerate_dir = _a.asarrayd(degenerate_dir) - degenerate_dir /= (degenerate_dir ** 2).sum() ** 0.5 - deg_dHk = sum(d*dh for d, dh in zip(degenerate_dir, dHk)) + degenerate_dir /= (degenerate_dir**2).sum() ** 0.5 + deg_dHk = sum(d * dh for d, dh in zip(degenerate_dir, dHk)) for deg in degenerate: # Set the average energy e = np.average(energy[deg]) @@ -624,13 +644,15 @@ def _velocity_matrix_non_ortho(state, dHk, energy, dSk, degenerate, degenerate_d # Now diagonalize to find the contributions from individual states # then re-construct the seperated degenerate states # Since we do this for all directions we should decouple them all - state[deg] = degenerate_decouple(state[deg], deg_dHk - sum(d * e * ds for d, ds in zip(degenerate_dir, dSk))) + state[deg] = degenerate_decouple( + state[deg], + deg_dHk - sum(d * e * ds for d, ds in zip(degenerate_dir, dSk)), + ) del deg_dHk # Since they depend on the state energies and dSk we have to loop them individually. cs = conj(state) for s, e in enumerate(energy): - # Since dHk *may* be a csr_matrix or sparse, we have to do it like # this. A sparse matrix cannot be re-shaped with an extra dimension. v[0, s] = cs @ (dHk[0] - e * dSk[0]).dot(state[s]) @@ -642,7 +664,7 @@ def _velocity_matrix_non_ortho(state, dHk, energy, dSk, degenerate, degenerate_d def _velocity_matrix_ortho(state, dHk, degenerate, degenerate_dir, dtype): - r""" For states in an orthogonal basis """ + r"""For states in an orthogonal basis""" # All matrix elements along the 3 directions n = state.shape[0] @@ -651,8 +673,8 @@ def _velocity_matrix_ortho(state, dHk, degenerate, degenerate_dir, dtype): # Decouple the degenerate states if not degenerate is None: degenerate_dir = _a.asarrayd(degenerate_dir) - degenerate_dir /= (degenerate_dir ** 2).sum() ** 0.5 - deg_dHk = sum(d*dh for d, dh in zip(degenerate_dir, dHk)) + degenerate_dir /= (degenerate_dir**2).sum() ** 0.5 + deg_dHk = sum(d * dh for d, dh in zip(degenerate_dir, dHk)) for deg in degenerate: # Now diagonalize to find the contributions from individual states # then re-construct the seperated degenerate states @@ -671,9 +693,10 @@ def _velocity_matrix_ortho(state, dHk, degenerate, degenerate_dir, dtype): @set_module("sisl.physics.electron") -def berry_curvature(state, energy, dHk, dSk=None, - degenerate=None, degenerate_dir=(1, 1, 1)): - r""" Calculate the Berry curvature matrix for a set of states (using Kubo) +def berry_curvature( + state, energy, dHk, dSk=None, degenerate=None, degenerate_dir=(1, 1, 1) +): + r"""Calculate the Berry curvature matrix for a set of states (using Kubo) The Berry curvature is calculated using the following expression (:math:`\alpha`, :math:`\beta` corresponding to Cartesian directions): @@ -720,23 +743,31 @@ def berry_curvature(state, energy, dHk, dSk=None, Berry flux with final dimension ``(3, 3, state.shape[0])`` """ if state.ndim == 1: - return berry_curvature(state.reshape(1, -1), energy, dHk, dSk, degenerate, degenerate_dir)[0] + return berry_curvature( + state.reshape(1, -1), energy, dHk, dSk, degenerate, degenerate_dir + )[0] - dtype = find_common_type([state.dtype, dHk[0].dtype, dtype_real_to_complex(state.dtype)], []) + dtype = find_common_type( + [state.dtype, dHk[0].dtype, dtype_real_to_complex(state.dtype)], [] + ) if dSk is None: v_matrix = _velocity_matrix_ortho(state, dHk, degenerate, degenerate_dir, dtype) else: - v_matrix = _velocity_matrix_non_ortho(state, dHk, energy, dSk, degenerate, degenerate_dir, dtype) - warn("berry_curvature calculation for non-orthogonal basis sets are not tested! Do not expect this to be correct!") + v_matrix = _velocity_matrix_non_ortho( + state, dHk, energy, dSk, degenerate, degenerate_dir, dtype + ) + warn( + "berry_curvature calculation for non-orthogonal basis sets are not tested! Do not expect this to be correct!" + ) return _berry_curvature(v_matrix, energy) # This reverses the velocity unit (squared since Berry curvature is v.v) -_berry_curvature_const = 1 / _velocity_const ** 2 +_berry_curvature_const = 1 / _velocity_const**2 def _berry_curvature(v_M, energy): - r""" Calculate Berry curvature for a given velocity matrix """ + r"""Calculate Berry curvature for a given velocity matrix""" # All matrix elements along the 3 directions N = v_M.shape[1] @@ -755,16 +786,21 @@ def _berry_curvature(v_M, energy): sigma[:, :, s] = ((de * v_M[:, s]) @ v_M[:, :, s].T).imag # negative here - sigma *= - _berry_curvature_const + sigma *= -_berry_curvature_const return sigma @set_module("sisl.physics.electron") -def conductivity(bz, distribution="fermi-dirac", method="ahc", - degenerate=1.e-5, degenerate_dir=(1, 1, 1), - *, - eigenstate_kwargs=None): - r""" Electronic conductivity for a given `BrillouinZone` integral +def conductivity( + bz, + distribution="fermi-dirac", + method="ahc", + degenerate=1.0e-5, + degenerate_dir=(1, 1, 1), + *, + eigenstate_kwargs=None, +): + r"""Electronic conductivity for a given `BrillouinZone` integral Currently the *only* implemented method is the anomalous Hall conductivity (AHC, see :cite:`Wang2006`) which may be calculated as: @@ -809,7 +845,9 @@ def conductivity(bz, distribution="fermi-dirac", method="ahc", # Currently we require the conductivity calculation to *only* accept Hamiltonians if not isinstance(bz.parent, Hamiltonian): - raise SislError("conductivity: requires the Brillouin zone object to contain a Hamiltonian!") + raise SislError( + "conductivity: requires the Brillouin zone object to contain a Hamiltonian!" + ) if isinstance(distribution, str): distribution = get_distribution(distribution) @@ -819,22 +857,30 @@ def conductivity(bz, distribution="fermi-dirac", method="ahc", method = method.lower() if method == "ahc": + def _ahc(es): occ = distribution(es.eig) - bc = es.berry_curvature(degenerate=degenerate, degenerate_dir=degenerate_dir) + bc = es.berry_curvature( + degenerate=degenerate, degenerate_dir=degenerate_dir + ) return bc @ occ vol, dim = bz.volume(ret_dim=True) if dim == 0: - raise SislError(f"conductivity: found a dimensionality of 0 which is non-physical") + raise SislError( + f"conductivity: found a dimensionality of 0 which is non-physical" + ) - cond = bz.apply.average.eigenstate(**eigenstate_kwargs, - wrap=_ahc) * (-constant.G0 / (4*np.pi)) + cond = bz.apply.average.eigenstate(**eigenstate_kwargs, wrap=_ahc) * ( + -constant.G0 / (4 * np.pi) + ) # Convert the dimensions from S/m^D to S/cm^D cond /= vol * units(f"Ang^{dim}", f"cm^{dim}") - warn("conductivity: be aware that the units are currently not tested, please provide feedback!") + warn( + "conductivity: be aware that the units are currently not tested, please provide feedback!" + ) else: raise SislError("conductivity: requires the method to be [ahc]") @@ -843,9 +889,16 @@ def _ahc(es): @set_module("sisl.physics.electron") -def berry_phase(contour, sub=None, eigvals=False, closed=True, method="berry", - *, - eigenstate_kwargs=None, ret_overlap=False): +def berry_phase( + contour, + sub=None, + eigvals=False, + closed=True, + method="berry", + *, + eigenstate_kwargs=None, + ret_overlap=False, +): r""" Calculate the Berry-phase on a loop path The Berry phase for a single Bloch state is calculated using the discretized formula: @@ -937,22 +990,28 @@ def berry_phase(contour, sub=None, eigvals=False, closed=True, method="berry", # Currently we require the Berry phase calculation to *only* accept Hamiltonians if not isinstance(contour.parent, Hamiltonian): - raise SislError("berry_phase: requires the Brillouin zone object to contain a Hamiltonian!") + raise SislError( + "berry_phase: requires the Brillouin zone object to contain a Hamiltonian!" + ) if eigenstate_kwargs is None: eigenstate_kwargs = {} if contour.parent.orthogonal: + def _lowdin(state): pass + else: gauge = eigenstate_kwargs.get("gauge", "R") + def _lowdin(state): - """ change state to the lowdin state, assuming everything is in R gauge - So needs to be done before changing gauge """ - S12 = sqrth(state.parent.Sk(state.info["k"], - gauge=gauge, format="array"), - overwrite_a=True) + """change state to the lowdin state, assuming everything is in R gauge + So needs to be done before changing gauge""" + S12 = sqrth( + state.parent.Sk(state.info["k"], gauge=gauge, format="array"), + overwrite_a=True, + ) state.state[:, :] = (S12 @ state.state.T).T method, *opts = method.lower().split(":") @@ -967,11 +1026,13 @@ def _lowdin(state): _process = dot if "svd" in opts: + def _process(prd, overlap): U, _, V = svd_destroy(overlap) return dot(prd, U @ V) if sub is None: + def _berry(eigenstates): # Grab the first one to be able to form a loop first = next(eigenstates) @@ -996,6 +1057,7 @@ def _berry(eigenstates): return prd else: + def _berry(eigenstates): nonlocal sub first = next(eigenstates) @@ -1032,7 +1094,7 @@ def _berry(eigenstates): @set_module("sisl.physics.electron") def wavefunction(v, grid, geometry=None, k=None, spinor=0, spin=None, eta=None): - r""" Add the wave-function (`Orbital.psi`) component of each orbital to the grid + r"""Add the wave-function (`Orbital.psi`) component of each orbital to the grid This routine calculates the real-space wave-function components in the specified grid. @@ -1108,40 +1170,54 @@ def wavefunction(v, grid, geometry=None, k=None, spinor=0, spin=None, eta=None): if k is None: k = v.info.get("k", k) elif not np.allclose(k, v.info.get("k", k)): - raise ValueError(f"wavefunction: k passed and k in info does not match: {k} and {v.info.get('k')}") + raise ValueError( + f"wavefunction: k passed and k in info does not match: {k} and {v.info.get('k')}" + ) v = v.state if geometry is None: geometry = grid.geometry if geometry is None: - raise SislError("wavefunction: did not find a usable Geometry through keywords or the Grid!") + raise SislError( + "wavefunction: did not find a usable Geometry through keywords or the Grid!" + ) # We cannot move stuff since outside stuff may rely on exact coordinates. # If people have out-liers, they should do it them-selves. # We'll do this and warn if they are dissimilar. dxyz = geometry.lattice.cell2length(1e-6).sum(0) - dxyz = geometry.move(dxyz).translate2uc(axes=(0, 1, 2)).move(-dxyz).xyz - geometry.xyz + dxyz = ( + geometry.move(dxyz).translate2uc(axes=(0, 1, 2)).move(-dxyz).xyz - geometry.xyz + ) if not np.allclose(dxyz, 0): - info(f"wavefunction: coordinates may be outside your primary unit-cell. " - "Translating all into the primary unit cell could disable this information") + info( + f"wavefunction: coordinates may be outside your primary unit-cell. " + "Translating all into the primary unit cell could disable this information" + ) # In case the user has passed several vectors we sum them to plot the summed state if v.ndim == 2: if v.shape[0] > 1: - info(f"wavefunction: summing {v.shape[0]} different state coefficients, will continue silently!") + info( + f"wavefunction: summing {v.shape[0]} different state coefficients, will continue silently!" + ) v = v.sum(0) if spin is None: if len(v) // 2 == geometry.no: # We can see from the input that the vector *must* be a non-colinear calculation v = v.reshape(-1, 2)[:, spinor] - info("wavefunction: assumes the input wavefunction coefficients to originate from a non-colinear calculation!") + info( + "wavefunction: assumes the input wavefunction coefficients to originate from a non-colinear calculation!" + ) elif spin.kind > Spin.POLARIZED: # For non-colinear cases the user selects the spinor component. v = v.reshape(-1, 2)[:, spinor] if len(v) != geometry.no: - raise ValueError("wavefunction: require wavefunction coefficients corresponding to number of orbitals in the geometry.") + raise ValueError( + "wavefunction: require wavefunction coefficients corresponding to number of orbitals in the geometry." + ) # Check for k-points k = _a.asarrayd(k) @@ -1156,7 +1232,9 @@ def wavefunction(v, grid, geometry=None, k=None, spinor=0, spin=None, eta=None): # Likewise if a k-point has been passed. is_complex = np.iscomplexobj(v) or has_k if is_complex and not np.iscomplexobj(grid.grid): - raise SislError("wavefunction: input coefficients are complex, while grid only contains real.") + raise SislError( + "wavefunction: input coefficients are complex, while grid only contains real." + ) if is_complex: psi_init = _a.zerosz @@ -1178,16 +1256,23 @@ def wavefunction(v, grid, geometry=None, k=None, spinor=0, spin=None, eta=None): old_err = np.seterr(divide="ignore", invalid="ignore") addouter = add.outer + def idx2spherical(ix, iy, iz, offset, dc, R): - """ Calculate the spherical coordinates from indices """ - rx = addouter(addouter(ix * dc[0, 0], iy * dc[1, 0]), iz * dc[2, 0] - offset[0]).ravel() - ry = addouter(addouter(ix * dc[0, 1], iy * dc[1, 1]), iz * dc[2, 1] - offset[1]).ravel() - rz = addouter(addouter(ix * dc[0, 2], iy * dc[1, 2]), iz * dc[2, 2] - offset[2]).ravel() + """Calculate the spherical coordinates from indices""" + rx = addouter( + addouter(ix * dc[0, 0], iy * dc[1, 0]), iz * dc[2, 0] - offset[0] + ).ravel() + ry = addouter( + addouter(ix * dc[0, 1], iy * dc[1, 1]), iz * dc[2, 1] - offset[1] + ).ravel() + rz = addouter( + addouter(ix * dc[0, 2], iy * dc[1, 2]), iz * dc[2, 2] - offset[2] + ).ravel() # Total size of the indices n = rx.shape[0] # Reduce our arrays to where the radius is "fine" - idx = indices_le(rx ** 2 + ry ** 2 + rz ** 2, R ** 2) + idx = indices_le(rx**2 + ry**2 + rz**2, R**2) rx = rx[idx] ry = ry[idx] rz = rz[idx] @@ -1225,7 +1310,7 @@ def idx2spherical(ix, iy, iz, offset, dc, R): if len(ia) == 0: continue R = atom.maxR() - all_negative_R = all_negative_R and R < 0. + all_negative_R = all_negative_R and R < 0.0 # Now do it for all the atoms to get indices of the middle of # the atoms @@ -1238,7 +1323,9 @@ def idx2spherical(ix, iy, iz, offset, dc, R): idx_mm[ia, 1, :] = idxM * R + idx if all_negative_R: - raise SislError("wavefunction: Cannot create wavefunction since no atoms have an associated basis-orbital on a real-space grid") + raise SislError( + "wavefunction: Cannot create wavefunction since no atoms have an associated basis-orbital on a real-space grid" + ) # Now we have min-max for all atoms # When we run the below loop all indices can be retrieved by looking @@ -1254,8 +1341,10 @@ def idx2spherical(ix, iy, iz, offset, dc, R): lattice = grid.lattice.copy() # Find the periodic directions - pbc = [bc == BC.PERIODIC or geometry.nsc[i] > 1 - for i, bc in enumerate(grid.lattice.boundary_condition[:, 0])] + pbc = [ + bc == BC.PERIODIC or geometry.nsc[i] > 1 + for i, bc in enumerate(grid.lattice.boundary_condition[:, 0]) + ] if grid.geometry is None: # Create the actual geometry that encompass the grid ia, xyz, _ = geometry.within_inf(lattice, periodic=pbc) @@ -1265,7 +1354,7 @@ def idx2spherical(ix, iy, iz, offset, dc, R): # Instead of looping all atoms in the supercell we find the exact atoms # and their supercell indices. # plus some tolerance - add_R = _a.fulld(3, geometry.maxR()) + 1.e-6 + add_R = _a.fulld(3, geometry.maxR()) + 1.0e-6 # Calculate the required additional vectors required to increase the fictitious # supercell by add_R in each direction. # For extremely skewed lattices this will be way too much, hence we make @@ -1296,8 +1385,10 @@ def idx2spherical(ix, iy, iz, offset, dc, R): # Extract maximum R R = atom.maxR() - if R <= 0.: - warn(f"wavefunction: Atom '{atom}' does not have a wave-function, skipping atom.") + if R <= 0.0: + warn( + f"wavefunction: Atom '{atom}' does not have a wave-function, skipping atom." + ) eta.update() continue @@ -1307,8 +1398,14 @@ def idx2spherical(ix, iy, iz, offset, dc, R): idxM = ceil(idx_mm[ia, 1, :] + idx).astype(int32) + 1 # Fast check whether we can skip this point - if idxm[0] >= shape[0] or idxm[1] >= shape[1] or idxm[2] >= shape[2] or \ - idxM[0] <= 0 or idxM[1] <= 0 or idxM[2] <= 0: + if ( + idxm[0] >= shape[0] + or idxm[1] >= shape[1] + or idxm[2] >= shape[2] + or idxM[0] <= 0 + or idxM[1] <= 0 + or idxM[2] <= 0 + ): eta.update() continue @@ -1328,9 +1425,14 @@ def idx2spherical(ix, iy, iz, offset, dc, R): # Now idxm/M contains min/max indices used # Convert to spherical coordinates - n, idx, r, theta, phi = idx2spherical(arangei(idxm[0], idxM[0]), - arangei(idxm[1], idxM[1]), - arangei(idxm[2], idxM[2]), xyz, dcell, R) + n, idx, r, theta, phi = idx2spherical( + arangei(idxm[0], idxM[0]), + arangei(idxm[1], idxM[1]), + arangei(idxm[2], idxM[2]), + xyz, + dcell, + R, + ) # Get initial orbital io = geometry.a2o(ia) @@ -1343,12 +1445,13 @@ def idx2spherical(ix, iy, iz, offset, dc, R): # Loop on orbitals on this atom, grouped by radius for os in atom.iter(True): - # Get the radius of orbitals (os) oR = os[0].R - if oR <= 0.: - warn(f"wavefunction: Orbital(s) '{os}' does not have a wave-function, skipping orbital!") + if oR <= 0.0: + warn( + f"wavefunction: Orbital(s) '{os}' does not have a wave-function, skipping orbital!" + ) # Skip these orbitals io += len(os) continue @@ -1370,7 +1473,9 @@ def idx2spherical(ix, iy, iz, offset, dc, R): # Loop orbitals with the same radius for o in os: # Evaluate psi component of the wavefunction and add it for this atom - psi[idx1] += o.psi_spher(r1, theta1, phi1, cos_phi=True) * (v[io] * phase) + psi[idx1] += o.psi_spher(r1, theta1, phi1, cos_phi=True) * ( + v[io] * phase + ) io += 1 # Clean-up @@ -1378,7 +1483,7 @@ def idx2spherical(ix, iy, iz, offset, dc, R): # Convert to correct shape and add the current atom contribution to the wavefunction psi.shape = idxM - idxm - grid.grid[idxm[0]:idxM[0], idxm[1]:idxM[1], idxm[2]:idxM[2]] += psi + grid.grid[idxm[0] : idxM[0], idxm[1] : idxM[1], idxm[2] : idxM[2]] += psi # Clean-up del psi @@ -1397,14 +1502,14 @@ class _electron_State: __slots__ = [] def __is_nc(self): - """ Internal routine to check whether this is a non-colinear calculation """ + """Internal routine to check whether this is a non-colinear calculation""" try: return not self.parent.spin.is_diagonal except Exception: return False def Sk(self, format=None, spin=None): - r""" Retrieve the overlap matrix corresponding to the originating parent structure. + r"""Retrieve the overlap matrix corresponding to the originating parent structure. When ``self.parent`` is a Hamiltonian this will return :math:`\mathbf S(k)` for the :math:`k`-point these eigenstates originate from @@ -1424,9 +1529,11 @@ def Sk(self, format=None, spin=None): if isinstance(self.parent, SparseOrbitalBZSpin): # Calculate the overlap matrix if not self.parent.orthogonal: - opt = {"k": self.info.get("k", (0, 0, 0)), - "dtype": self.dtype, - "format": format} + opt = { + "k": self.info.get("k", (0, 0, 0)), + "dtype": self.dtype, + "format": format, + } for key in ("gauge",): val = self.info.get(key, None) if not val is None: @@ -1445,7 +1552,7 @@ def Sk(self, format=None, spin=None): return _FakeMatrix(n, m) def norm2(self, sum=True): - r""" Return a vector with the norm of each state :math:`\langle\psi|\mathbf S|\psi\rangle` + r"""Return a vector with the norm of each state :math:`\langle\psi|\mathbf S|\psi\rangle` :math:`\mathbf S` is the overlap matrix (or basis), for orthogonal basis :math:`\mathbf S \equiv \mathbf I`. @@ -1469,7 +1576,7 @@ def norm2(self, sum=True): return conj(self.state) * S.dot(self.state.T).T def spin_moment(self, project=False): - r""" Calculate spin moment from the states + r"""Calculate spin moment from the states This routine calls `~sisl.physics.electron.spin_moment` with appropriate arguments and returns the spin moment for the states. @@ -1484,7 +1591,7 @@ def spin_moment(self, project=False): return spin_moment(self.state, self.Sk(), project=project) def wavefunction(self, grid, spinor=0, eta=None): - r""" Expand the coefficients as the wavefunction on `grid` *as-is* + r"""Expand the coefficients as the wavefunction on `grid` *as-is* See `~sisl.physics.electron.wavefunction` for argument details, the arguments not present in this method are automatically passed from this object. @@ -1507,28 +1614,30 @@ def wavefunction(self, grid, spinor=0, eta=None): # Retrieve k k = self.info.get("k", _a.zerosd(3)) - wavefunction(self.state, grid, geometry=geometry, k=k, spinor=spinor, spin=spin, eta=eta) + wavefunction( + self.state, grid, geometry=geometry, k=k, spinor=spinor, spin=spin, eta=eta + ) @set_module("sisl.physics.electron") class CoefficientElectron(Coefficient): - r""" Coefficients describing some physical quantity related to electrons """ + r"""Coefficients describing some physical quantity related to electrons""" __slots__ = [] @set_module("sisl.physics.electron") class StateElectron(_electron_State, State): - r""" A state describing a physical quantity related to electrons """ + r"""A state describing a physical quantity related to electrons""" __slots__ = [] @set_module("sisl.physics.electron") class StateCElectron(_electron_State, StateC): - r""" A state describing a physical quantity related to electrons, with associated coefficients of the state """ + r"""A state describing a physical quantity related to electrons, with associated coefficients of the state""" __slots__ = [] def velocity(self, *args, **kwargs): - r""" Calculate velocity for the states + r"""Calculate velocity for the states This routine calls ``derivative(1, *args, **kwargs)`` and returns the velocity for the states. @@ -1560,7 +1669,7 @@ def velocity(self, *args, **kwargs): return v def berry_curvature(self, *args, **kwargs): - r""" Calculate Berry curvature for the states + r"""Calculate Berry curvature for the states This routine calls ``derivative(1, *args, **kwargs, matrix=True)`` and returns the Berry curvature for the states. @@ -1577,7 +1686,7 @@ def berry_curvature(self, *args, **kwargs): return _berry_curvature(v, self.c) def effective_mass(self, *args, **kwargs): - r""" Calculate effective mass tensor for the states, units are (ps/Ang)^2 + r"""Calculate effective mass tensor for the states, units are (ps/Ang)^2 This routine calls ``derivative(2, *args, **kwargs)`` and returns the effective mass for all states. @@ -1605,13 +1714,13 @@ def effective_mass(self, *args, **kwargs): derivative: for details of the implementation """ ieff = self.derivative(2, *args, **kwargs)[1].real - np.divide(_velocity_const ** 2, ieff, where=(ieff != 0), out=ieff) + np.divide(_velocity_const**2, ieff, where=(ieff != 0), out=ieff) return ieff @set_module("sisl.physics.electron") class EigenvalueElectron(CoefficientElectron): - r""" Eigenvalues of electronic states, no eigenvectors retained + r"""Eigenvalues of electronic states, no eigenvectors retained This holds routines that enable the calculation of density of states. """ @@ -1619,11 +1728,11 @@ class EigenvalueElectron(CoefficientElectron): @property def eig(self): - """ Eigenvalues """ + """Eigenvalues""" return self.c def occupation(self, distribution="fermi_dirac"): - r""" Calculate the occupations for the states according to a distribution function + r"""Calculate the occupations for the states according to a distribution function Parameters ---------- @@ -1640,7 +1749,7 @@ def occupation(self, distribution="fermi_dirac"): return distribution(self.eig) def DOS(self, E, distribution="gaussian"): - r""" Calculate DOS for provided energies, `E`. + r"""Calculate DOS for provided energies, `E`. This routine calls `sisl.physics.electron.DOS` with appropriate arguments and returns the DOS. @@ -1652,7 +1761,7 @@ def DOS(self, E, distribution="gaussian"): @set_module("sisl.physics.electron") class EigenvectorElectron(StateElectron): - r""" Eigenvectors of electronic states, no eigenvalues retained + r"""Eigenvectors of electronic states, no eigenvalues retained This holds routines that enable the calculation of spin moments. """ @@ -1661,7 +1770,7 @@ class EigenvectorElectron(StateElectron): @set_module("sisl.physics.electron") class EigenstateElectron(StateCElectron): - r""" Eigen states of electrons with eigenvectors and eigenvalues. + r"""Eigen states of electrons with eigenvectors and eigenvalues. This holds routines that enable the calculation of (projected) density of states, spin moments (spin texture). @@ -1670,11 +1779,11 @@ class EigenstateElectron(StateCElectron): @property def eig(self): - r""" Eigenvalues for each state """ + r"""Eigenvalues for each state""" return self.c def occupation(self, distribution="fermi_dirac"): - r""" Calculate the occupations for the states according to a distribution function + r"""Calculate the occupations for the states according to a distribution function Parameters ---------- @@ -1691,7 +1800,7 @@ def occupation(self, distribution="fermi_dirac"): return distribution(self.eig) def DOS(self, E, distribution="gaussian"): - r""" Calculate DOS for provided energies, `E`. + r"""Calculate DOS for provided energies, `E`. This routine calls `sisl.physics.electron.DOS` with appropriate arguments and returns the DOS. @@ -1701,24 +1810,31 @@ def DOS(self, E, distribution="gaussian"): return DOS(E, self.c, distribution) def PDOS(self, E, distribution="gaussian"): - r""" Calculate PDOS for provided energies, `E`. + r"""Calculate PDOS for provided energies, `E`. This routine calls `~sisl.physics.electron.PDOS` with appropriate arguments and returns the PDOS. See `~sisl.physics.electron.PDOS` for argument details. """ - return PDOS(E, self.c, self.state, self.Sk(), distribution, getattr(self.parent, "spin", None)) + return PDOS( + E, + self.c, + self.state, + self.Sk(), + distribution, + getattr(self.parent, "spin", None), + ) def COP(self, E, M, *args, **kwargs): - r""" Calculate COP for provided energies, `E` using matrix `M` + r"""Calculate COP for provided energies, `E` using matrix `M` This routine calls `~sisl.physics.electron.COP` with appropriate arguments. """ return COP(E, self.c, self.state, M, *args, **kwargs) def COOP(self, E, *args, **kwargs): - r""" Calculate COOP for provided energies, `E`. + r"""Calculate COOP for provided energies, `E`. This routine calls `~sisl.physics.electron.COP` with appropriate arguments. """ @@ -1727,7 +1843,7 @@ def COOP(self, E, *args, **kwargs): return COP(E, self.c, self.state, Sk, *args, **kwargs) def COHP(self, E, *args, **kwargs): - r""" Calculate COHP for provided energies, `E`. + r"""Calculate COHP for provided energies, `E`. This routine calls `~sisl.physics.electron.COP` with appropriate arguments. """ diff --git a/src/sisl/physics/energydensitymatrix.py b/src/sisl/physics/energydensitymatrix.py index 3fd411a0df..d712e8642c 100644 --- a/src/sisl/physics/energydensitymatrix.py +++ b/src/sisl/physics/energydensitymatrix.py @@ -14,7 +14,7 @@ @set_module("sisl.physics") class EnergyDensityMatrix(_densitymatrix): - """ Sparse energy density matrix object + """Sparse energy density matrix object Assigning or changing elements is as easy as with standard `numpy` assignments: @@ -94,12 +94,12 @@ def _reset(self): @property def E(self): - r""" Access the energy density matrix elements """ + r"""Access the energy density matrix elements""" self._def_dim = self.UP return self - def Ek(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the energy density matrix for a given k-point + def Ek(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the energy density matrix for a given k-point Creation and return of the energy density matrix for a given k-point (default to Gamma). @@ -154,8 +154,8 @@ def Ek(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): """ pass - def dEk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the energy density matrix derivative for a given k-point + def dEk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the energy density matrix derivative for a given k-point Creation and return of the energy density matrix derivative for a given k-point (default to Gamma). @@ -209,8 +209,8 @@ def dEk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs) """ pass - def ddEk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the energy density matrix double derivative for a given k-point + def ddEk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the energy density matrix double derivative for a given k-point Creation and return of the energy density matrix double derivative for a given k-point (default to Gamma). @@ -265,7 +265,7 @@ def ddEk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs pass def shift(self, E, DM): - r""" Shift the energy density matrix to a common energy by using a reference density matrix + r"""Shift the energy density matrix to a common energy by using a reference density matrix This is equal to performing this operation: @@ -285,14 +285,16 @@ def shift(self, E, DM): density matrix corresponding to the same geometry """ if not self.spsame(DM): - raise SislError(f"{self.__class__.__name__}.shift requires the input DM to have " - "the same sparsity as the shifted object.") + raise SislError( + f"{self.__class__.__name__}.shift requires the input DM to have " + "the same sparsity as the shifted object." + ) E = _a.asarrayd(E) if E.size == 1: E = np.tile(E, 2) - if np.abs(E).sum() == 0.: + if np.abs(E).sum() == 0.0: # When the energy is zero, there is no shift return @@ -301,7 +303,7 @@ def shift(self, E, DM): @staticmethod def read(sile, *args, **kwargs): - """ Reads density matrix from `Sile` using `read_energy_density_matrix`. + """Reads density matrix from `Sile` using `read_energy_density_matrix`. Parameters ---------- @@ -314,19 +316,21 @@ def read(sile, *args, **kwargs): # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_energy_density_matrix(*args, **kwargs) else: - with get_sile(sile, mode='r') as fh: + with get_sile(sile, mode="r") as fh: return fh.read_energy_density_matrix(*args, **kwargs) def write(self, sile, *args, **kwargs) -> None: - """ Writes a density matrix to the `Sile` as implemented in the :code:`Sile.write_energy_density_matrix` method """ + """Writes a density matrix to the `Sile` as implemented in the :code:`Sile.write_energy_density_matrix` method""" # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): sile.write_energy_density_matrix(self, *args, **kwargs) else: - with get_sile(sile, mode='w') as fh: + with get_sile(sile, mode="w") as fh: fh.write_energy_density_matrix(self, *args, **kwargs) diff --git a/src/sisl/physics/hamiltonian.py b/src/sisl/physics/hamiltonian.py index a75d775780..6d3224793d 100644 --- a/src/sisl/physics/hamiltonian.py +++ b/src/sisl/physics/hamiltonian.py @@ -15,7 +15,7 @@ @set_module("sisl.physics") class Hamiltonian(SparseOrbitalBZSpin): - """ Sparse Hamiltonian matrix object + """Sparse Hamiltonian matrix object Assigning or changing Hamiltonian elements is as easy as with standard `numpy` assignments: @@ -84,7 +84,7 @@ class Hamiltonian(SparseOrbitalBZSpin): """ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): - """ Initialize Hamiltonian """ + """Initialize Hamiltonian""" super().__init__(geometry, dim, dtype, nnzpr, **kwargs) self._reset() @@ -96,12 +96,12 @@ def _reset(self): @property def H(self): - r""" Access the Hamiltonian elements """ + r"""Access the Hamiltonian elements""" self._def_dim = self.UP return self - def Hk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the Hamiltonian for a given k-point + def Hk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the Hamiltonian for a given k-point Creation and return of the Hamiltonian for a given k-point (default to Gamma). @@ -157,8 +157,8 @@ def Hk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): """ pass - def dHk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the Hamiltonian derivative for a given k-point + def dHk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the Hamiltonian derivative for a given k-point Creation and return of the Hamiltonian derivative for a given k-point (default to Gamma). @@ -212,8 +212,8 @@ def dHk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs) """ pass - def ddHk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs): - r""" Setup the Hamiltonian double derivative for a given k-point + def ddHk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): + r"""Setup the Hamiltonian double derivative for a given k-point Creation and return of the Hamiltonian double derivative for a given k-point (default to Gamma). @@ -268,7 +268,7 @@ def ddHk(self, k=(0, 0, 0), dtype=None, gauge='R', format='csr', *args, **kwargs pass def shift(self, E): - r""" Shift the electronic structure by a constant energy + r"""Shift the electronic structure by a constant energy This is equal to performing this operation: @@ -288,7 +288,7 @@ def shift(self, E): if E.size == 1: E = np.tile(E, 2) - if np.abs(E).sum() == 0.: + if np.abs(E).sum() == 0.0: # When the energy is zero, there is no shift return @@ -302,8 +302,8 @@ def shift(self, E): for i in range(self.spin.spinor): self._csr._D[:, i] += self._csr._D[:, self.S_idx] * E[i] - def eigenvalue(self, k=(0, 0, 0), gauge='R', **kwargs): - """ Calculate the eigenvalues at `k` and return an `EigenvalueElectron` object containing all eigenvalues for a given `k` + def eigenvalue(self, k=(0, 0, 0), gauge="R", **kwargs): + """Calculate the eigenvalues at `k` and return an `EigenvalueElectron` object containing all eigenvalues for a given `k` Parameters ---------- @@ -330,11 +330,11 @@ def eigenvalue(self, k=(0, 0, 0), gauge='R', **kwargs): EigenvalueElectron """ format = kwargs.pop("format", None) - if kwargs.pop('sparse', False): + if kwargs.pop("sparse", False): e = self.eigsh(k, gauge=gauge, eigvals_only=True, **kwargs) else: e = self.eigh(k, gauge, eigvals_only=True, **kwargs) - info = {'k': k, 'gauge': gauge} + info = {"k": k, "gauge": gauge} for name in ["spin"]: if name in kwargs: info[name] = kwargs[name] @@ -342,8 +342,8 @@ def eigenvalue(self, k=(0, 0, 0), gauge='R', **kwargs): info["format"] = format return EigenvalueElectron(e, self, **info) - def eigenstate(self, k=(0, 0, 0), gauge='R', **kwargs): - """ Calculate the eigenstates at `k` and return an `EigenstateElectron` object containing all eigenstates + def eigenstate(self, k=(0, 0, 0), gauge="R", **kwargs): + """Calculate the eigenstates at `k` and return an `EigenstateElectron` object containing all eigenstates Parameters ---------- @@ -370,11 +370,11 @@ def eigenstate(self, k=(0, 0, 0), gauge='R', **kwargs): EigenstateElectron """ format = kwargs.pop("format", None) - if kwargs.pop('sparse', False): + if kwargs.pop("sparse", False): e, v = self.eigsh(k, gauge=gauge, eigvals_only=False, **kwargs) else: e, v = self.eigh(k, gauge, eigvals_only=False, **kwargs) - info = {'k': k, 'gauge': gauge} + info = {"k": k, "gauge": gauge} for name in ["spin"]: if name in kwargs: info[name] = kwargs[name] @@ -385,7 +385,7 @@ def eigenstate(self, k=(0, 0, 0), gauge='R', **kwargs): @staticmethod def read(sile, *args, **kwargs): - """ Reads Hamiltonian from `Sile` using `read_hamiltonian`. + """Reads Hamiltonian from `Sile` using `read_hamiltonian`. Parameters ---------- @@ -398,25 +398,27 @@ def read(sile, *args, **kwargs): # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_hamiltonian(*args, **kwargs) else: - with get_sile(sile, mode='r') as fh: + with get_sile(sile, mode="r") as fh: return fh.read_hamiltonian(*args, **kwargs) def write(self, sile, *args, **kwargs) -> None: - """ Writes a Hamiltonian to the `Sile` as implemented in the :code:`Sile.write_hamiltonian` method """ + """Writes a Hamiltonian to the `Sile` as implemented in the :code:`Sile.write_hamiltonian` method""" # This only works because, they *must* # have been imported previously from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): sile.write_hamiltonian(self, *args, **kwargs) else: - with get_sile(sile, mode='w') as fh: + with get_sile(sile, mode="w") as fh: fh.write_hamiltonian(self, *args, **kwargs) - def fermi_level(self, bz=None, q=None, distribution='fermi_dirac', q_tol=1e-10): - """ Calculate the Fermi-level using a Brillouinzone sampling and a target charge + def fermi_level(self, bz=None, q=None, distribution="fermi_dirac", q_tol=1e-10): + """Calculate the Fermi-level using a Brillouinzone sampling and a target charge The Fermi-level will be calculated using an iterative approach by first calculating all eigenvalues and subsequently fitting the Fermi level to the final charge (`q`). @@ -443,6 +445,7 @@ def fermi_level(self, bz=None, q=None, distribution='fermi_dirac', q_tol=1e-10): if bz is None: # Gamma-point only from .brillouinzone import BrillouinZone + bz = BrillouinZone(self) else: # Overwrite the parent in bz @@ -456,9 +459,11 @@ def fermi_level(self, bz=None, q=None, distribution='fermi_dirac', q_tol=1e-10): # Ensure we have an "array" in case of spin-polarized calculations q = _a.asarrayd(q) - if np.any(q <= 0.): - raise ValueError(f"{self.__class__.__name__}.fermi_level cannot calculate the Fermi level " - "for 0 electrons.") + if np.any(q <= 0.0): + raise ValueError( + f"{self.__class__.__name__}.fermi_level cannot calculate the Fermi level " + "for 0 electrons." + ) if isinstance(distribution, str): distribution = get_distribution(distribution) @@ -496,8 +501,10 @@ def _Ef(q, eig): if self.spin.is_polarized and q.size == 2: if np.any(q >= len(self)): - raise ValueError(f"{self.__class__.__name__}.fermi_level cannot calculate the Fermi level " - "for electrons ({q}) equal to or above number of orbitals ({len(self)}).") + raise ValueError( + f"{self.__class__.__name__}.fermi_level cannot calculate the Fermi level " + "for electrons ({q}) equal to or above number of orbitals ({len(self)})." + ) # We need to do Fermi-level separately since the user requests # separate fillings Ef = _a.emptyd(2) @@ -507,11 +514,12 @@ def _Ef(q, eig): # Ensure a single charge q = q.sum() if q >= len(self): - raise ValueError(f"{self.__class__.__name__}.fermi_level cannot calculate the Fermi level " - "for electrons ({q}) equal to or above number of orbitals ({len(self)}).") + raise ValueError( + f"{self.__class__.__name__}.fermi_level cannot calculate the Fermi level " + "for electrons ({q}) equal to or above number of orbitals ({len(self)})." + ) if self.spin.is_polarized: - Ef = _Ef(q, np.concatenate([eigh(spin=0), - eigh(spin=1)], axis=1)) + Ef = _Ef(q, np.concatenate([eigh(spin=0), eigh(spin=1)], axis=1)) else: Ef = _Ef(q, eigh()) diff --git a/src/sisl/physics/overlap.py b/src/sisl/physics/overlap.py index d834e3154e..2d78520c20 100644 --- a/src/sisl/physics/overlap.py +++ b/src/sisl/physics/overlap.py @@ -12,7 +12,7 @@ @set_module("sisl.physics") class Overlap(SparseOrbitalBZ): - r""" Sparse overlap matrix object + r"""Sparse overlap matrix object The Overlap object contains orbital overlaps. It should be used when the overlaps are not associated with another physical object such as a Hamiltonian, as is the case with eg. Siesta onlyS outputs. @@ -35,7 +35,7 @@ class Overlap(SparseOrbitalBZ): """ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): - r""" Initialize Overlap """ + r"""Initialize Overlap""" # Since this *is* the overlap matrix, we should never use the # orthogonal keyword kwargs["orthogonal"] = True @@ -50,13 +50,13 @@ def _reset(self): @property def S(self): - r""" Access the overlap elements """ + r"""Access the overlap elements""" self._def_dim = 0 return self @classmethod def fromsp(cls, geometry, P, **kwargs): - r""" Create an Overlap object from a preset `Geometry` and a sparse matrix + r"""Create an Overlap object from a preset `Geometry` and a sparse matrix The passed sparse matrix is in one of `scipy.sparse` formats. @@ -84,7 +84,7 @@ def fromsp(cls, geometry, P, **kwargs): @staticmethod def read(sile, *args, **kwargs): - """ Reads Overlap from `Sile` using `read_overlap`. + """Reads Overlap from `Sile` using `read_overlap`. Parameters ---------- @@ -95,17 +95,19 @@ def read(sile, *args, **kwargs): * : args passed directly to ``read_overlap(,**)`` """ from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): return sile.read_overlap(*args, **kwargs) else: - with get_sile(sile, mode='r') as fh: + with get_sile(sile, mode="r") as fh: return fh.read_overlap(*args, **kwargs) def write(self, sile, *args, **kwargs) -> None: - """ Writes the Overlap to the `Sile` as implemented in the :code:`Sile.write_overlap` method """ + """Writes the Overlap to the `Sile` as implemented in the :code:`Sile.write_overlap` method""" from sisl.io import BaseSile, get_sile + if isinstance(sile, BaseSile): sile.write_overlap(self, *args, **kwargs) else: - with get_sile(sile, mode='w') as fh: + with get_sile(sile, mode="w") as fh: fh.write_overlap(self, *args, **kwargs) diff --git a/src/sisl/physics/phonon.py b/src/sisl/physics/phonon.py index 355a7107eb..e9df643da6 100644 --- a/src/sisl/physics/phonon.py +++ b/src/sisl/physics/phonon.py @@ -43,14 +43,14 @@ from .electron import PDOS as electron_PDOS from .state import Coefficient, State, StateC, degenerate_decouple -__all__ = ['DOS', 'PDOS'] -__all__ += ['CoefficientPhonon', 'ModePhonon', 'ModeCPhonon'] -__all__ += ['EigenvaluePhonon', 'EigenvectorPhonon', 'EigenmodePhonon'] +__all__ = ["DOS", "PDOS"] +__all__ += ["CoefficientPhonon", "ModePhonon", "ModeCPhonon"] +__all__ += ["EigenvaluePhonon", "EigenvectorPhonon", "EigenmodePhonon"] @set_module("sisl.physics.phonon") -def DOS(E, hw, distribution='gaussian'): - r""" Calculate the density of modes (DOS) for a set of energies, `E`, with a distribution function +def DOS(E, hw, distribution="gaussian"): + r"""Calculate the density of modes (DOS) for a set of energies, `E`, with a distribution function The :math:`\mathrm{DOS}(E)` is calculated as: @@ -85,8 +85,8 @@ def DOS(E, hw, distribution='gaussian'): @set_module("sisl.physics.phonon") -def PDOS(E, mode, hw, distribution='gaussian'): - r""" Calculate the projected density of modes (PDOS) onto each each atom and direction for a set of energies, `E`, with a distribution function +def PDOS(E, mode, hw, distribution="gaussian"): + r"""Calculate the projected density of modes (PDOS) onto each each atom and direction for a set of energies, `E`, with a distribution function The :math:`\mathrm{PDOS}(E)` is calculated as: @@ -127,41 +127,46 @@ def PDOS(E, mode, hw, distribution='gaussian'): # dDk is in [Ang * eV ** 2] # velocity units in Ang/ps -_velocity_const = 1 / constant.hbar('eV ps') +_velocity_const = 1 / constant.hbar("eV ps") -_displacement_const = (2 * units('Ry', 'eV') * (constant.m_e / constant.m_p)) ** 0.5 * units('Bohr', 'Ang') +_displacement_const = ( + 2 * units("Ry", "eV") * (constant.m_e / constant.m_p) +) ** 0.5 * units("Bohr", "Ang") @set_module("sisl.physics.phonon") class CoefficientPhonon(Coefficient): - """ Coefficients describing some physical quantity related to phonons """ + """Coefficients describing some physical quantity related to phonons""" + __slots__ = [] @set_module("sisl.physics.phonon") class ModePhonon(State): - """ A mode describing a physical quantity related to phonons """ + """A mode describing a physical quantity related to phonons""" + __slots__ = [] @property def mode(self): - """ Eigenmodes (states) """ + """Eigenmodes (states)""" return self.state @set_module("sisl.physics.phonon") class ModeCPhonon(StateC): - """ A mode describing a physical quantity related to phonons, with associated coefficients of the mode """ + """A mode describing a physical quantity related to phonons, with associated coefficients of the mode""" + __slots__ = [] @property def mode(self): - """ Eigenmodes (states) """ + """Eigenmodes (states)""" return self.state def velocity(self, *args, **kwargs): - r""" Calculate velocity of the modes + r"""Calculate velocity of the modes This routine calls `derivative` with appropriate arguments (1st order derivative) and returns the velocity for the modes. @@ -190,19 +195,20 @@ def velocity(self, *args, **kwargs): @set_module("sisl.physics.phonon") class EigenvaluePhonon(CoefficientPhonon): - """ Eigenvalues of phonon modes, no eigenmodes retained + """Eigenvalues of phonon modes, no eigenmodes retained This holds routines that enable the calculation of density of states. """ + __slots__ = [] @property def hw(self): - r""" Eigenmode values in units of :math:`\hbar \omega` [eV] """ + r"""Eigenmode values in units of :math:`\hbar \omega` [eV]""" return self.c - def occupation(self, distribution='bose_einstein'): - """ Calculate the occupations for the states according to a distribution function + def occupation(self, distribution="bose_einstein"): + """Calculate the occupations for the states according to a distribution function Parameters ---------- @@ -218,8 +224,8 @@ def occupation(self, distribution='bose_einstein'): distribution = get_distribution(distribution) return distribution(self.hw) - def DOS(self, E, distribution='gaussian'): - r""" Calculate DOS for provided energies, `E`. + def DOS(self, E, distribution="gaussian"): + r"""Calculate DOS for provided energies, `E`. This routine calls `sisl.physics.phonon.DOS` with appropriate arguments and returns the DOS. @@ -231,25 +237,27 @@ def DOS(self, E, distribution='gaussian'): @set_module("sisl.physics.phonon") class EigenvectorPhonon(ModePhonon): - """ Eigenvectors of phonon modes, no eigenvalues retained """ + """Eigenvectors of phonon modes, no eigenvalues retained""" + __slots__ = [] @set_module("sisl.physics.phonon") class EigenmodePhonon(ModeCPhonon): - """ Eigenmodes of phonons with eigenvectors and eigenvalues. + """Eigenmodes of phonons with eigenvectors and eigenvalues. This holds routines that enable the calculation of (projected) density of states. """ + __slots__ = [] @property def hw(self): - r""" Eigenmode values in units of :math:`\hbar \omega` [eV] """ + r"""Eigenmode values in units of :math:`\hbar \omega` [eV]""" return self.c - def occupation(self, distribution='bose_einstein'): - """ Calculate the occupations for the states according to a distribution function + def occupation(self, distribution="bose_einstein"): + """Calculate the occupations for the states according to a distribution function Parameters ---------- @@ -265,8 +273,8 @@ def occupation(self, distribution='bose_einstein'): distribution = get_distribution(distribution) return distribution(self.hw) - def DOS(self, E, distribution='gaussian'): - r""" Calculate DOS for provided energies, `E`. + def DOS(self, E, distribution="gaussian"): + r"""Calculate DOS for provided energies, `E`. This routine calls `sisl.physics.phonon.DOS` with appropriate arguments and returns the DOS. @@ -275,8 +283,8 @@ def DOS(self, E, distribution='gaussian'): """ return DOS(E, self.hw, distribution) - def PDOS(self, E, distribution='gaussian'): - r""" Calculate PDOS for provided energies, `E`. + def PDOS(self, E, distribution="gaussian"): + r"""Calculate PDOS for provided energies, `E`. This routine calls `~sisl.physics.phonon.PDOS` with appropriate arguments and returns the PDOS. @@ -286,7 +294,7 @@ def PDOS(self, E, distribution='gaussian'): return PDOS(E, self.mode, self.hw, distribution) def displacement(self): - r""" Calculate real-space displacements for a given mode (in units of the characteristic length) + r"""Calculate real-space displacements for a given mode (in units of the characteristic length) The displacements per mode may be written as: @@ -317,7 +325,7 @@ def displacement(self): idx = (self.c == 0).nonzero()[0] mode = self.mode U = mode.copy() - U[idx, :] = 0. + U[idx, :] = 0.0 # Now create the remaining displacements idx = delete(_a.arangei(U.shape[0]), idx) @@ -326,6 +334,8 @@ def displacement(self): factor = _displacement_const / fabs(self.c[idx]).reshape(-1, 1) ** 0.5 U.shape = (U.shape[0], -1, 3) - U[idx] = (mode[idx, :] * factor).reshape(len(idx), -1, 3) / self.parent.mass.reshape(1, -1, 1) ** 0.5 + U[idx] = (mode[idx, :] * factor).reshape( + len(idx), -1, 3 + ) / self.parent.mass.reshape(1, -1, 1) ** 0.5 U = np.swapaxes(U, 0, 2) return U diff --git a/src/sisl/physics/self_energy.py b/src/sisl/physics/self_energy.py index 2361c52ea6..cffdf4b702 100644 --- a/src/sisl/physics/self_energy.py +++ b/src/sisl/physics/self_energy.py @@ -35,7 +35,7 @@ @set_module("sisl.physics") class SelfEnergy: - r""" Self-energy object able to calculate the dense self-energy for a given sparse matrix + r"""Self-energy object able to calculate the dense self-energy for a given sparse matrix The self-energy object contains a `SparseGeometry` object which, in it-self contains the geometry. @@ -44,7 +44,7 @@ class SelfEnergy: """ def __init__(self, *args, **kwargs): - r""" Self-energy class for constructing a self-energy. """ + r"""Self-energy class for constructing a self-energy.""" pass def __len__(self): @@ -53,7 +53,7 @@ def __len__(self): @staticmethod def se2broadening(SE): - r""" Calculate the broadening matrix from the self-energy + r"""Calculate the broadening matrix from the self-energy .. math:: \boldsymbol\Gamma = i(\boldsymbol\Sigma - \boldsymbol \Sigma ^\dagger) @@ -66,14 +66,14 @@ def se2broadening(SE): return 1j * (SE - conjugate(SE.T)) def _setup(self, *args, **kwargs): - """ Class specific setup routine """ + """Class specific setup routine""" pass def self_energy(self, *args, **kwargs): raise NotImplementedError def broadening_matrix(self, *args, **kwargs): - r""" Calculate the broadening matrix by first calculating the self-energy + r"""Calculate the broadening matrix by first calculating the self-energy Any arguments that is passed to this method is directly passed to `self_energy`. @@ -112,13 +112,13 @@ def broadening_matrix(self, *args, **kwargs): return self.se2broadening(self.self_energy(*args, **kwargs)) def __getattr__(self, attr): - r""" Overload attributes from the hosting object """ + r"""Overload attributes from the hosting object""" pass @set_module("sisl.physics") class WideBandSE(SelfEnergy): - r""" Self-energy object with a wide-band electronic structure + r"""Self-energy object with a wide-band electronic structure Such a self-energy only have imaginary components on the diagonal, with all of them being equal to the `eta` value. @@ -143,7 +143,7 @@ def __len__(self): return self._N def self_energy(self, *args, **kwargs): - r""" Return a dense matrix with the self-energy + r"""Return a dense matrix with the self-energy Parameters ---------- @@ -151,20 +151,20 @@ def self_energy(self, *args, **kwargs): locally override the `eta` value for the object """ # note the sign (-) - eta = - kwargs.get("eta", self.eta) - return np.diag(np.repeat(1j*eta, self._N)) + eta = -kwargs.get("eta", self.eta) + return np.diag(np.repeat(1j * eta, self._N)) - def broadening_matrix(self, E=0., *args, **kwargs): + def broadening_matrix(self, E=0.0, *args, **kwargs): # note the sign (+) eta = kwargs.get("eta", self.eta) - return np.diag(np.repeat(np.complex128(2*eta), self._N)) + return np.diag(np.repeat(np.complex128(2 * eta), self._N)) broadening_matrix.__doc__ = SelfEnergy.broadening_matrix.__doc__ @set_module("sisl.physics") class SemiInfinite(SelfEnergy): - r""" Self-energy object able to calculate the dense self-energy for a given `SparseGeometry` in a semi-infinite chain. + r"""Self-energy object able to calculate the dense self-energy for a given `SparseGeometry` in a semi-infinite chain. Parameters ---------- @@ -177,7 +177,7 @@ class SemiInfinite(SelfEnergy): """ def __init__(self, spgeom, infinite, eta=1e-4): - """ Create a `SelfEnergy` object from any `SparseGeometry` """ + """Create a `SelfEnergy` object from any `SparseGeometry`""" self.eta = eta # Determine whether we are in plus/minus direction @@ -186,7 +186,9 @@ def __init__(self, spgeom, infinite, eta=1e-4): elif infinite.startswith("-"): self.semi_inf_dir = -1 else: - raise ValueError(f"{self.__class__.__name__} infinite keyword does not start with `+` or `-`.") + raise ValueError( + f"{self.__class__.__name__} infinite keyword does not start with `+` or `-`." + ) # Determine the direction INF = infinite.upper() @@ -197,41 +199,49 @@ def __init__(self, spgeom, infinite, eta=1e-4): elif INF.endswith("C"): self.semi_inf = 2 else: - raise ValueError(f"{self.__class__.__name__} infinite keyword does not end with `A`, `B` or `C`.") + raise ValueError( + f"{self.__class__.__name__} infinite keyword does not end with `A`, `B` or `C`." + ) # Check that the Hamiltonian does have a non-zero V along the semi-infinite direction if spgeom.geometry.lattice.nsc[self.semi_inf] == 1: - warn("Creating a semi-infinite self-energy with no couplings along the semi-infinite direction") + warn( + "Creating a semi-infinite self-energy with no couplings along the semi-infinite direction" + ) # Finalize the setup by calling the class specific routine self._setup(spgeom) def __str__(self): - """ String representation of SemiInfinite """ - return "{0}{{direction: {1}{2}}}".format(self.__class__.__name__, - {-1: "-", 1: "+"}.get(self.semi_inf_dir), - {0: "A", 1: "B", 2: "C"}.get(self.semi_inf)) + """String representation of SemiInfinite""" + return "{0}{{direction: {1}{2}}}".format( + self.__class__.__name__, + {-1: "-", 1: "+"}.get(self.semi_inf_dir), + {0: "A", 1: "B", 2: "C"}.get(self.semi_inf), + ) @set_module("sisl.physics") class RecursiveSI(SemiInfinite): - """ Self-energy object using the Lopez-Sancho Lopez-Sancho algorithm """ + """Self-energy object using the Lopez-Sancho Lopez-Sancho algorithm""" def __getattr__(self, attr): - """ Overload attributes from the hosting object """ + """Overload attributes from the hosting object""" return getattr(self.spgeom0, attr) def __str__(self): - """ Representation of the RecursiveSI model """ + """Representation of the RecursiveSI model""" direction = {-1: "-", 1: "+"} axis = {0: "A", 1: "B", 2: "C"} - return "{0}{{direction: {1}{2},\n {3}\n}}".format(self.__class__.__name__, - direction[self.semi_inf_dir], axis[self.semi_inf], - str(self.spgeom0).replace("\n", "\n "), + return "{0}{{direction: {1}{2},\n {3}\n}}".format( + self.__class__.__name__, + direction[self.semi_inf_dir], + axis[self.semi_inf], + str(self.spgeom0).replace("\n", "\n "), ) def _setup(self, spgeom): - """ Setup the Lopez-Sancho internals for easy axes """ + """Setup the Lopez-Sancho internals for easy axes""" # Create spgeom0 and spgeom1 self.spgeom0 = spgeom.copy() @@ -257,9 +267,11 @@ def _setup(self, spgeom): diff.eliminate_zeros() rem_nnz = diff.nnz diff = np.amax(diff, axis=(0, 1)) - warn(f"{self.__class__.__name__}: {spgeom.__class__.__name__} has connections across the first neighbouring cell. " - f"{rem_nnz} non-zero values will be forced to 0 as the principal cell-interaction is a requirement. " - f"The maximum values of the removed connections are: {diff}") + warn( + f"{self.__class__.__name__}: {spgeom.__class__.__name__} has connections across the first neighbouring cell. " + f"{rem_nnz} non-zero values will be forced to 0 as the principal cell-interaction is a requirement. " + f"The maximum values of the removed connections are: {diff}" + ) # I.e. we will delete all interactions that are un-important n_s = self.spgeom1.geometry.lattice.n_s @@ -268,8 +280,12 @@ def _setup(self, spgeom): nsc = [None] * 3 nsc[self.semi_inf] = self.semi_inf_dir # Get all supercell indices that we should delete - idx = np.delete(_a.arangei(n_s), - _a.arrayi(self.spgeom1.geometry.lattice.sc_index(nsc))) * n + idx = ( + np.delete( + _a.arangei(n_s), _a.arrayi(self.spgeom1.geometry.lattice.sc_index(nsc)) + ) + * n + ) cols = _a.array_arange(idx, idx + n) # Delete all values in columns, but keep them to retain the supercell information @@ -280,7 +296,7 @@ def __len__(self): return len(self.spgeom0) def green(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, **kwargs): - r""" Return a dense matrix with the bulk Green function at energy `E` and k-point `k` (default Gamma). + r"""Return a dense matrix with the bulk Green function at energy `E` and k-point `k` (default Gamma). Parameters ---------- @@ -301,7 +317,7 @@ def green(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, **kwargs): numpy.ndarray the self-energy corresponding to the semi-infinite direction """ - if E.imag == 0.: + if E.imag == 0.0: E = E.real + 1j * self.eta # Get k-point @@ -316,7 +332,9 @@ def green(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, **kwargs): # As the SparseGeometry inherently works for # orthogonal and non-orthogonal basis, there is no # need to have two algorithms. - GB = sp0.Sk(k, dtype=dtype, format="array") * E - sp0.Pk(k, dtype=dtype, format="array", **kwargs) + GB = sp0.Sk(k, dtype=dtype, format="array") * E - sp0.Pk( + k, dtype=dtype, format="array", **kwargs + ) n = GB.shape[0] ab = empty([n, 2, n], dtype=dtype) @@ -347,19 +365,24 @@ def green(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, **kwargs): getri = linalg_info("getri", dtype) getri_lwork = linalg_info("getri_lwork", dtype) lwork = int(1.01 * _compute_lwork(getri_lwork, n)) + def inv(A): lu, piv, info = getrf(A, overwrite_a=True) if info == 0: x, info = getri(lu, piv, lwork=lwork, overwrite_lu=True) if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not compute the inverse.") + raise ValueError( + f"{self.__class__.__name__}.green could not compute the inverse." + ) return x while True: _, _, tab, info = gesv(GB, ab2, overwrite_a=False, overwrite_b=False) tab.shape = shape if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not solve G x = B system!") + raise ValueError( + f"{self.__class__.__name__}.green could not solve G x = B system!" + ) # Update bulk Green function subtract(GB, matmul(alpha, tab[:, 1, :]), out=GB) @@ -375,10 +398,12 @@ def inv(A): del ab, alpha, beta, ab2, tab return inv(GB) - raise ValueError(f"{self.__class__.__name__}.green could not converge Green function calculation") + raise ValueError( + f"{self.__class__.__name__}.green could not converge Green function calculation" + ) def self_energy(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kwargs): - r""" Return a dense matrix with the self-energy at energy `E` and k-point `k` (default Gamma). + r"""Return a dense matrix with the self-energy at energy `E` and k-point `k` (default Gamma). Parameters ---------- @@ -402,7 +427,7 @@ def self_energy(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kwarg numpy.ndarray the self-energy corresponding to the semi-infinite direction """ - if E.imag == 0.: + if E.imag == 0.0: E = E.real + 1j * self.eta # Get k-point @@ -417,7 +442,9 @@ def self_energy(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kwarg # As the SparseGeometry inherently works for # orthogonal and non-orthogonal basis, there is no # need to have two algorithms. - GB = sp0.Sk(k, dtype=dtype, format="array") * E - sp0.Pk(k, dtype=dtype, format="array", **kwargs) + GB = sp0.Sk(k, dtype=dtype, format="array") * E - sp0.Pk( + k, dtype=dtype, format="array", **kwargs + ) n = GB.shape[0] ab = empty([n, 2, n], dtype=dtype) @@ -451,12 +478,14 @@ def self_energy(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kwarg gesv = linalg_info("gesv", dtype) # Specifying dot with "out" argument should be faster - tmp = empty_like(GS, order='C') + tmp = empty_like(GS, order="C") while True: _, _, tab, info = gesv(GB, ab2, overwrite_a=False, overwrite_b=False) tab.shape = shape if info != 0: - raise ValueError(f"{self.__class__.__name__}.self_energy could not solve G x = B system!") + raise ValueError( + f"{self.__class__.__name__}.self_energy could not solve G x = B system!" + ) matmul(alpha, tab[:, 1, :], out=tmp) # Update bulk Green function @@ -475,12 +504,16 @@ def self_energy(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kwarg del ab, alpha, beta, ab2, tab, GB if bulk: return GS - return - GS + return -GS - raise ValueError(f"{self.__class__.__name__}: could not converge self-energy calculation") + raise ValueError( + f"{self.__class__.__name__}: could not converge self-energy calculation" + ) - def self_energy_lr(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kwargs): - r""" Return two dense matrices with the left/right self-energy at energy `E` and k-point `k` (default Gamma). + def self_energy_lr( + self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kwargs + ): + r"""Return two dense matrices with the left/right self-energy at energy `E` and k-point `k` (default Gamma). Note calculating the LR self-energies simultaneously requires that their chemical potentials are the same. I.e. only when the reference energy is equivalent in the left/right schemes does this make sense. @@ -509,7 +542,7 @@ def self_energy_lr(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kw right : numpy.ndarray the right self-energy """ - if E.imag == 0.: + if E.imag == 0.0: E = E.real + 1j * self.eta # Get k-point @@ -524,7 +557,9 @@ def self_energy_lr(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kw # As the SparseGeometry inherently works for # orthogonal and non-orthogonal basis, there is no # need to have two algorithms. - SmH0 = sp0.Sk(k, dtype=dtype, format="array") * E - sp0.Pk(k, dtype=dtype, format="array", **kwargs) + SmH0 = sp0.Sk(k, dtype=dtype, format="array") * E - sp0.Pk( + k, dtype=dtype, format="array", **kwargs + ) GB = SmH0.copy() n = GB.shape[0] @@ -559,12 +594,14 @@ def self_energy_lr(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kw gesv = linalg_info("gesv", dtype) # Specifying dot with "out" argument should be faster - tmp = empty_like(GS, order='C') + tmp = empty_like(GS, order="C") while True: _, _, tab, info = gesv(GB, ab2, overwrite_a=False, overwrite_b=False) tab.shape = shape if info != 0: - raise ValueError(f"{self.__class__.__name__}.self_energy_lr could not solve G x = B system!") + raise ValueError( + f"{self.__class__.__name__}.self_energy_lr could not solve G x = B system!" + ) matmul(alpha, tab[:, 1, :], out=tmp) # Update bulk Green function @@ -585,18 +622,20 @@ def self_energy_lr(self, E, k=(0, 0, 0), dtype=None, eps=1e-14, bulk=False, **kw # GS is the "right" self-energy if bulk: return GB - GS + SmH0, GS - return GS - GB + SmH0, - GS + return GS - GB + SmH0, -GS # GS is the "left" self-energy if bulk: return GS, GB - GS + SmH0 - return - GS, GS - GB + SmH0 + return -GS, GS - GB + SmH0 - raise ValueError(f"{self.__class__.__name__}: could not converge self-energy (LR) calculation") + raise ValueError( + f"{self.__class__.__name__}: could not converge self-energy (LR) calculation" + ) @set_module("sisl.physics") class RealSpaceSE(SelfEnergy): - r""" Bulk real-space self-energy (or Green function) for a given physical object with periodicity + r"""Bulk real-space self-energy (or Green function) for a given physical object with periodicity The real-space self-energy is calculated via the k-averaged Green function: @@ -657,7 +696,7 @@ class RealSpaceSE(SelfEnergy): """ def __init__(self, parent, semi_axis, k_axes, unfold=(1, 1, 1), **options): - """ Initialize real-space self-energy calculator """ + """Initialize real-space self-energy calculator""" self.parent = parent # Store axes @@ -668,15 +707,21 @@ def __init__(self, parent, semi_axis, k_axes, unfold=(1, 1, 1), **options): s_ax = self._semi_axis k_ax = self._k_axes if s_ax in k_ax: - raise ValueError(f"{self.__class__.__name__} found the self-energy direction to be " - "the same as one of the k-axes, this is not allowed.") + raise ValueError( + f"{self.__class__.__name__} found the self-energy direction to be " + "the same as one of the k-axes, this is not allowed." + ) if np.any(self.parent.nsc[k_ax] < 3): - raise ValueError(f"{self.__class__.__name__} found k-axes without periodicity. " - "Correct k_axes via .set_options.") + raise ValueError( + f"{self.__class__.__name__} found k-axes without periodicity. " + "Correct k_axes via .set_options." + ) if self.parent.nsc[s_ax] != 3: - raise ValueError(f"{self.__class__.__name__} found the self-energy direction to be " - "incompatible with the parent object. It *must* have 3 supercells along the " - "semi-infinite direction.") + raise ValueError( + f"{self.__class__.__name__} found the self-energy direction to be " + "incompatible with the parent object. It *must* have 3 supercells along the " + "semi-infinite direction." + ) # Local variables for the completion of the details self._unfold = _a.arrayi([max(1, un) for un in unfold]) @@ -699,7 +744,7 @@ def __len__(self): return len(self.parent) * np.prod(self._unfold) def __str__(self): - """ String representation of RealSpaceSE """ + """String representation of RealSpaceSE""" d = {"class": self.__class__.__name__} for i in range(3): d[f"u{i}"] = self._unfold[i] @@ -708,13 +753,15 @@ def __str__(self): d["parent"] = str(self.parent).replace("\n", "\n ") d["bz"] = str(self._options["bz"]).replace("\n", "\n ") d["trs"] = str(self._options["trs"]) - return ("{class}{{unfold: [{u0}, {u1}, {u2}],\n " - "semi-axis: {semi}, k-axes: {k}, trs: {trs},\n " - "bz: {bz},\n " - "{parent}\n}}").format(**d) + return ( + "{class}{{unfold: [{u0}, {u1}, {u2}],\n " + "semi-axis: {semi}, k-axes: {k}, trs: {trs},\n " + "bz: {bz},\n " + "{parent}\n}}" + ).format(**d) def set_options(self, **options): - r""" Update options in the real-space self-energy + r"""Update options in the real-space self-energy After updating options one should re-call `initialize` for consistency. @@ -736,7 +783,7 @@ def set_options(self, **options): self._options.update(options) def real_space_parent(self): - """ Return the parent object in the real-space unfolded region """ + """Return the parent object in the real-space unfolded region""" s_ax = self._semi_axis k_ax = self._k_axes # Always start with the semi-infinite direction, since we @@ -756,7 +803,7 @@ def real_space_parent(self): return P0 def real_space_coupling(self, ret_indices=False): - r""" Real-space coupling parent where sites fold into the parent real-space unit cell + r"""Real-space coupling parent where sites fold into the parent real-space unit cell The resulting parent object only contains the inner-cell couplings for the elements that couple out of the real-space matrix. @@ -799,7 +846,9 @@ def real_space_coupling(self, ret_indices=False): n = PC.shape[0] idx = g.lattice.sc_index([0, 0, 0]) cols = _a.arangei(idx * n, (idx + 1) * n) - csr = PC._csr.copy([0]) # we just want the sparse pattern, so forget about the other elements + csr = PC._csr.copy( + [0] + ) # we just want the sparse pattern, so forget about the other elements csr.delete_columns(cols, keep_shape=True) # Now PC only contains couplings along the k and semi-inf directions # Extract the connecting orbitals and reduce them to unique atomic indices @@ -827,7 +876,7 @@ def real_space_coupling(self, ret_indices=False): return PC def initialize(self): - r""" Initialize the internal data-arrays used for efficient calculation of the real-space quantities + r"""Initialize the internal data-arrays used for efficient calculation of the real-space quantities This method should first be called *after* all options has been specified. @@ -874,8 +923,10 @@ def initialize(self): nk[k_ax] = np.ceil(self._options["dk"] * rcell).astype(np.int32) self._options["bz"] = MonkhorstPack(lattice, nk, trs=self._options["trs"]) - def self_energy(self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, **kwargs): - r""" Calculate the real-space self-energy + def self_energy( + self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, **kwargs + ): + r"""Calculate the real-space self-energy The real space self-energy is calculated via: @@ -912,10 +963,24 @@ def self_energy(self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, ** if coupling: orbs = self._calc["orbs"] iorbs = delete(_a.arangei(len(G)), orbs).reshape(-1, 1) - SeH = self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs) + SeH = self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"]( + k, dtype=dtype, **kwargs + ) if bulk: - return solve(G[orbs, orbs.T], eye(orbs.size, dtype=dtype) - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), True, True) - return SeH[orbs, orbs.T].toarray() - solve(G[orbs, orbs.T], eye(orbs.size, dtype=dtype) - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), True, True) + return solve( + G[orbs, orbs.T], + eye(orbs.size, dtype=dtype) + - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), + True, + True, + ) + return SeH[orbs, orbs.T].toarray() - solve( + G[orbs, orbs.T], + eye(orbs.size, dtype=dtype) + - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), + True, + True, + ) # Another way to do the coupling calculation would be the *full* thing # which should always be slower. @@ -923,22 +988,25 @@ def self_energy(self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, ** # since comparing the two yields numerical differences on the order 1e-8 eV depending # on the size of the full matrix G. - #orbs = self._calc["orbs"] - #iorbs = _a.arangei(orbs.size).reshape(1, -1) - #I = zeros([G.shape[0], orbs.size], dtype) + # orbs = self._calc["orbs"] + # iorbs = _a.arangei(orbs.size).reshape(1, -1) + # I = zeros([G.shape[0], orbs.size], dtype) ### Set diagonal - #I[orbs.ravel(), iorbs.ravel()] = 1. - #if bulk: + # I[orbs.ravel(), iorbs.ravel()] = 1. + # if bulk: # return solve(G, I, True, True)[orbs, iorbs] - #return (self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs))[orbs, orbs.T].toarray() \ + # return (self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs))[orbs, orbs.T].toarray() \ # - solve(G, I, True, True)[orbs, iorbs] if bulk: return inv(G, True) - return (self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs)).toarray() - inv(G, True) + return ( + self._calc["S0"](k, dtype=dtype) * E + - self._calc["P0"](k, dtype=dtype, **kwargs) + ).toarray() - inv(G, True) def green(self, E, k=(0, 0, 0), dtype=None, **kwargs): - r""" Calculate the real-space Green function + r"""Calculate the real-space Green function The real space Green function is calculated via: @@ -979,13 +1047,17 @@ def green(self, E, k=(0, 0, 0), dtype=None, **kwargs): k_ax = self._k_axes k = _a.asarrayd(k) - is_k = np.any(k != 0.) + is_k = np.any(k != 0.0) if is_k: axes = [s_ax] + k_ax.tolist() - if np.any(k[axes] != 0.): - raise ValueError(f"{self.__class__.__name__}.green requires the k-point to be zero along the integrated axes.") + if np.any(k[axes] != 0.0): + raise ValueError( + f"{self.__class__.__name__}.green requires the k-point to be zero along the integrated axes." + ) if trs: - raise ValueError(f"{self.__class__.__name__}.green requires a k-point sampled Green function to not use time reversal symmetry.") + raise ValueError( + f"{self.__class__.__name__}.green requires a k-point sampled Green function to not use time reversal symmetry." + ) # Shift k-points to get the correct k-point in the larger one. bz._k += k.reshape(1, 3) @@ -1003,12 +1075,15 @@ def green(self, E, k=(0, 0, 0), dtype=None, **kwargs): getri = linalg_info("getri", dtype) getri_lwork = linalg_info("getri_lwork", dtype) lwork = int(1.01 * _compute_lwork(getri_lwork, len(self._calc["SE"].spgeom0))) + def inv(A): lu, piv, info = getrf(A, overwrite_a=True) if info == 0: x, info = getri(lu, piv, lwork=lwork, overwrite_lu=True) if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not compute the inverse.") + raise ValueError( + f"{self.__class__.__name__}.green could not compute the inverse." + ) return x if tile == 1: @@ -1018,14 +1093,24 @@ def inv(A): if self.parent.orthogonal: # Orthogonal *always* identity S0E = eye(len(M0), dtype=dtype) * E + def _calc_green(k, dtype, no, tile, idx0): SL, SR = SE(E, k, dtype=dtype, **kwargs) - return inv(S0E - M0Pk(k, dtype=dtype, format="array", **kwargs) - SL - SR) + return inv( + S0E - M0Pk(k, dtype=dtype, format="array", **kwargs) - SL - SR + ) + else: M0Sk = M0.Sk + def _calc_green(k, dtype, no, tile, idx0): SL, SR = SE(E, k, dtype=dtype, **kwargs) - return inv(M0Sk(k, dtype=dtype, format="array") * E - M0Pk(k, dtype=dtype, format="array", **kwargs) - SL - SR) + return inv( + M0Sk(k, dtype=dtype, format="array") * E + - M0Pk(k, dtype=dtype, format="array", **kwargs) + - SL + - SR + ) else: # Get faster methods since we don't want overhead of solve @@ -1033,20 +1118,29 @@ def _calc_green(k, dtype, no, tile, idx0): M1 = self._calc["SE"].spgeom1 M1Pk = M1.Pk if self.parent.orthogonal: + def _calc_green(k, dtype, no, tile, idx0): # Calculate left/right self-energies - Gf, A2 = SE(E, k, dtype=dtype, bulk=True, **kwargs) # A1 == Gf, because of memory usage + Gf, A2 = SE( + E, k, dtype=dtype, bulk=True, **kwargs + ) # A1 == Gf, because of memory usage # skip negation since we don't do negation on tY/tX B = M1Pk(k, dtype=dtype, format="array", **kwargs) # C = conjugate(B.T) - _, _, tY, info = gesv(Gf, conjugate(B.T), overwrite_a=True, overwrite_b=True) + _, _, tY, info = gesv( + Gf, conjugate(B.T), overwrite_a=True, overwrite_b=True + ) if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not solve tY x = B system!") + raise ValueError( + f"{self.__class__.__name__}.green could not solve tY x = B system!" + ) Gf[:, :] = inv(A2 - matmul(B, tY)) _, _, tX, info = gesv(A2, B, overwrite_a=True, overwrite_b=True) if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not solve tX x = B system!") + raise ValueError( + f"{self.__class__.__name__}.green could not solve tX x = B system!" + ) # Since this is the pristine case, we know that # G11 and G22 are the same: @@ -1055,33 +1149,55 @@ def _calc_green(k, dtype, no, tile, idx0): G = empty([tile, no, tile, no], dtype=dtype) G[idx0, :, idx0, :] = Gf.reshape(1, no, no) for i in range(1, tile): - G[idx0[i:], :, idx0[:-i], :] = matmul(tX, G[i-1, :, 0, :]).reshape(1, no, no) - G[idx0[:-i], :, idx0[i:], :] = matmul(tY, G[0, :, i-1, :]).reshape(1, no, no) + G[idx0[i:], :, idx0[:-i], :] = matmul( + tX, G[i - 1, :, 0, :] + ).reshape(1, no, no) + G[idx0[:-i], :, idx0[i:], :] = matmul( + tY, G[0, :, i - 1, :] + ).reshape(1, no, no) return G.reshape(tile * no, -1) else: M1Sk = M1.Sk + def _calc_green(k, dtype, no, tile, idx0): - Gf, A2 = SE(E, k, dtype=dtype, bulk=True, **kwargs) # A1 == Gf, because of memory usage - tY = M1Sk(k, dtype=dtype, format="array") # S - tX = M1Pk(k, dtype=dtype, format="array", **kwargs) # H + Gf, A2 = SE( + E, k, dtype=dtype, bulk=True, **kwargs + ) # A1 == Gf, because of memory usage + tY = M1Sk(k, dtype=dtype, format="array") # S + tX = M1Pk(k, dtype=dtype, format="array", **kwargs) # H # negate B to allow faster gesv method B = tX - tY * E # C = _conj(tY.T) * E - _conj(tX.T) - _, _, tY[:, :], info = gesv(Gf, conjugate(tX.T) - conjugate(tY.T) * E, overwrite_a=True, overwrite_b=True) + _, _, tY[:, :], info = gesv( + Gf, + conjugate(tX.T) - conjugate(tY.T) * E, + overwrite_a=True, + overwrite_b=True, + ) if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not solve tY x = B system!") + raise ValueError( + f"{self.__class__.__name__}.green could not solve tY x = B system!" + ) Gf[:, :] = inv(A2 - matmul(B, tY)) - _, _, tX[:, :], info = gesv(A2, B, overwrite_a=True, overwrite_b=True) + _, _, tX[:, :], info = gesv( + A2, B, overwrite_a=True, overwrite_b=True + ) if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not solve tX x = B system!") + raise ValueError( + f"{self.__class__.__name__}.green could not solve tX x = B system!" + ) G = empty([tile, no, tile, no], dtype=dtype) G[idx0, :, idx0, :] = Gf.reshape(1, no, no) for i in range(1, tile): - G[idx0[i:], :, idx0[:-i], :] = matmul(tX, G[i-1, :, 0, :]).reshape(1, no, no) - G[idx0[:-i], :, idx0[i:], :] = matmul(tY, G[0, :, i-1, :]).reshape(1, no, no) + G[idx0[i:], :, idx0[:-i], :] = matmul( + tX, G[i - 1, :, 0, :] + ).reshape(1, no, no) + G[idx0[:-i], :, idx0[i:], :] = matmul( + tY, G[0, :, i - 1, :] + ).reshape(1, no, no) return G.reshape(tile * no, -1) # Create functions used to calculate the real-space Green function @@ -1092,8 +1208,10 @@ def _calc_green(k, dtype, no, tile, idx0): # If using Bloch's theorem we need to wrap the Green function calculation # as the method call. if len(bloch) > 1: + def _func_bloch(k, dtype, no, tile, idx0): return bloch(_calc_green, k, dtype=dtype, no=no, tile=tile, idx0=idx0) + else: _func_bloch = _calc_green @@ -1114,13 +1232,13 @@ def _func_bloch(k, dtype, no, tile, idx0): return G def clear(self): - """ Clears the internal arrays created in `initialize` """ + """Clears the internal arrays created in `initialize`""" del self._calc @set_module("sisl.physics") class RealSpaceSI(SelfEnergy): - r""" Surface real-space self-energy (or Green function) for a given physical object with limited periodicity + r"""Surface real-space self-energy (or Green function) for a given physical object with limited periodicity The surface real-space self-energy is calculated via the k-averaged Green function: @@ -1190,34 +1308,44 @@ class RealSpaceSI(SelfEnergy): """ def __init__(self, semi, surface, k_axes, unfold=(1, 1, 1), **options): - """ Initialize real-space self-energy calculator """ + """Initialize real-space self-energy calculator""" self.semi = semi self.surface = surface if not self.semi.lattice.parallel(surface.lattice): - raise ValueError(f"{self.__class__.__name__} requires semi and surface to have parallel " - "lattice vectors.") + raise ValueError( + f"{self.__class__.__name__} requires semi and surface to have parallel " + "lattice vectors." + ) self._k_axes = np.sort(_a.arrayi(k_axes).ravel()) k_ax = self._k_axes if self.semi.semi_inf in k_ax: - raise ValueError(f"{self.__class__.__name__} found the self-energy direction to be " - "the same as one of the k-axes, this is not allowed.") + raise ValueError( + f"{self.__class__.__name__} found the self-energy direction to be " + "the same as one of the k-axes, this is not allowed." + ) # Local variables for the completion of the details self._unfold = _a.arrayi([max(1, un) for un in unfold]) if self.surface.nsc[semi.semi_inf] > 1: - raise ValueError(f"{self.__class__.__name__} surface has periodicity along the semi-infinite " - "direction. This is not allowed.") + raise ValueError( + f"{self.__class__.__name__} surface has periodicity along the semi-infinite " + "direction. This is not allowed." + ) if np.any(self.surface.nsc[k_ax] < 3): - raise ValueError(f"{self.__class__.__name__} found k-axes without periodicity. " - "Correct `k_axes` via `.set_option`.") + raise ValueError( + f"{self.__class__.__name__} found k-axes without periodicity. " + "Correct `k_axes` via `.set_option`." + ) if self._unfold[semi.semi_inf] > 1: - raise ValueError(f"{self.__class__.__name__} cannot unfold along the semi-infinite direction. " - "This is a surface real-space self-energy.") + raise ValueError( + f"{self.__class__.__name__} cannot unfold along the semi-infinite direction. " + "This is a surface real-space self-energy." + ) # Now we need to figure out the atoms in the surface that corresponds to the # semi-infinite direction. @@ -1245,8 +1373,10 @@ def __init__(self, semi, surface, k_axes, unfold=(1, 1, 1), **options): if not np.allclose(self.semi.geometry.xyz, g_surf, rtol=0, atol=1e-3): print("Coordinate difference:") print(self.semi.geometry.xyz - g_surf) - raise ValueError(f"{self.__class__.__name__} overlapping semi-infinite " - "and surface atoms does not coincide!") + raise ValueError( + f"{self.__class__.__name__} overlapping semi-infinite " + "and surface atoms does not coincide!" + ) # Surface orbitals to put in the semi-infinite self-energy into. orbs = self.surface.geometry.a2o(atoms, True) @@ -1280,7 +1410,7 @@ def __len__(self): return len(self.surface) * np.prod(self._unfold) def __str__(self): - """ String representation of RealSpaceSI """ + """String representation of RealSpaceSI""" d = {"class": self.__class__.__name__} for i in range(3): d[f"u{i}"] = self._unfold[i] @@ -1289,16 +1419,18 @@ def __str__(self): d["surface"] = str(self.surface).replace("\n", "\n ") d["bz"] = str(self._options["bz"]).replace("\n", "\n ") d["trs"] = str(self._options["trs"]) - return ("{class}{{unfold: [{u0}, {u1}, {u2}],\n " - "k-axes: {k}, trs: {trs},\n " - "bz: {bz},\n " - "semi-infinite:\n" - " bulk: {self._options['semi_bulk']},\n" - " {semi},\n " - "surface:\n {surface}\n}}").format(**d) + return ( + "{class}{{unfold: [{u0}, {u1}, {u2}],\n " + "k-axes: {k}, trs: {trs},\n " + "bz: {bz},\n " + "semi-infinite:\n" + " bulk: {self._options['semi_bulk']},\n" + " {semi},\n " + "surface:\n {surface}\n}}" + ).format(**d) def set_options(self, **options): - r""" Update options in the real-space self-energy + r"""Update options in the real-space self-energy After updating options one should re-call `initialize` for consistency. @@ -1322,7 +1454,7 @@ def set_options(self, **options): self._options.update(options) def real_space_parent(self): - r""" Fully expanded real-space surface parent + r"""Fully expanded real-space surface parent Notes ----- @@ -1342,7 +1474,7 @@ def real_space_parent(self): return P0 def real_space_coupling(self, ret_indices=False): - r""" Real-space coupling surface where the outside fold into the surface real-space unit cell + r"""Real-space coupling surface where the outside fold into the surface real-space unit cell The resulting parent object only contains the inner-cell couplings for the elements that couple out of the real-space matrix. @@ -1383,7 +1515,9 @@ def real_space_coupling(self, ret_indices=False): # out of the self-energy used. The k-axis retains all atoms, per see. nsc = array_replace(PC_k.nsc, (k_ax, None), (self.semi.semi_inf, None), other=1) PC_k.set_nsc(nsc) - nsc = array_replace(PC_semi.nsc, (k_ax, None), (self.semi.semi_inf, None), other=1) + nsc = array_replace( + PC_semi.nsc, (k_ax, None), (self.semi.semi_inf, None), other=1 + ) PC_semi.set_nsc(nsc) nsc = array_replace(PC.nsc, (k_ax, None), other=1) PC.set_nsc(nsc) @@ -1398,7 +1532,9 @@ def get_connections(PC, nrep=1, na=0, na_off=0): n = PC.shape[0] idx = g.lattice.sc_index([0, 0, 0]) cols = _a.arangei(idx * n, (idx + 1) * n) - csr = PC._csr.copy([0]) # we just want the sparse pattern, so forget about the other elements + csr = PC._csr.copy( + [0] + ) # we just want the sparse pattern, so forget about the other elements csr.delete_columns(cols, keep_shape=True) # Now PC only contains couplings along the k and semi-inf directions # Extract the connecting orbitals and reduce them to unique atomic indices @@ -1422,8 +1558,12 @@ def expand(atom, nrep, na, na_off): if len(PC_semi.edges(atom)) > 0: atom_semi.append(atom) atom_semi = _a.arrayi(atom_semi) - expand(atom_semi, n_unfold, self.semi.spgeom1.geometry.na, self.surface.geometry.na) - atom_k = get_connections(PC_k, n_unfold, self.semi.spgeom0.geometry.na, self.surface.geometry.na) + expand( + atom_semi, n_unfold, self.semi.spgeom1.geometry.na, self.surface.geometry.na + ) + atom_k = get_connections( + PC_k, n_unfold, self.semi.spgeom0.geometry.na, self.surface.geometry.na + ) if self.semi.semi_inf_dir == 1: # we are dealing with *right* scheme, so last atoms. # Shift coordinates by the offset @@ -1454,7 +1594,7 @@ def expand(atom, nrep, na, na_off): return PC def initialize(self): - r""" Initialize the internal data-arrays used for efficient calculation of the real-space quantities + r"""Initialize the internal data-arrays used for efficient calculation of the real-space quantities This method should first be called *after* all options has been specified. @@ -1493,8 +1633,10 @@ def initialize(self): nk[self._k_axes] = np.ceil(self._options["dk"] * rcell).astype(np.int32) self._options["bz"] = MonkhorstPack(lattice, nk, trs=self._options["trs"]) - def self_energy(self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, **kwargs): - r""" Calculate real-space surface self-energy + def self_energy( + self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, **kwargs + ): + r"""Calculate real-space surface self-energy The real space self-energy is calculated via: @@ -1531,10 +1673,24 @@ def self_energy(self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, ** if coupling: orbs = self._calc["orbs"] iorbs = delete(_a.arangei(len(G)), orbs).reshape(-1, 1) - SeH = self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs) + SeH = self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"]( + k, dtype=dtype, **kwargs + ) if bulk: - return solve(G[orbs, orbs.T], eye(orbs.size, dtype=dtype) - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), True, True) - return SeH[orbs, orbs.T].toarray() - solve(G[orbs, orbs.T], eye(orbs.size, dtype=dtype) - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), True, True) + return solve( + G[orbs, orbs.T], + eye(orbs.size, dtype=dtype) + - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), + True, + True, + ) + return SeH[orbs, orbs.T].toarray() - solve( + G[orbs, orbs.T], + eye(orbs.size, dtype=dtype) + - matmul(G[orbs, iorbs.T], SeH[iorbs, orbs.T].toarray()), + True, + True, + ) # Another way to do the coupling calculation would be the *full* thing # which should always be slower. @@ -1542,22 +1698,25 @@ def self_energy(self, E, k=(0, 0, 0), bulk=False, coupling=False, dtype=None, ** # since comparing the two yields numerical differences on the order 1e-8 eV depending # on the size of the full matrix G. - #orbs = self._calc["orbs"] - #iorbs = _a.arangei(orbs.size).reshape(1, -1) - #I = zeros([G.shape[0], orbs.size], dtype) + # orbs = self._calc["orbs"] + # iorbs = _a.arangei(orbs.size).reshape(1, -1) + # I = zeros([G.shape[0], orbs.size], dtype) # Set diagonal - #I[orbs.ravel(), iorbs.ravel()] = 1. - #if bulk: + # I[orbs.ravel(), iorbs.ravel()] = 1. + # if bulk: # return solve(G, I, True, True)[orbs, iorbs] - #return (self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs))[orbs, orbs.T].toarray() \ + # return (self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs))[orbs, orbs.T].toarray() \ # - solve(G, I, True, True)[orbs, iorbs] if bulk: return inv(G, True) - return (self._calc["S0"](k, dtype=dtype) * E - self._calc["P0"](k, dtype=dtype, **kwargs)).toarray() - inv(G, True) + return ( + self._calc["S0"](k, dtype=dtype) * E + - self._calc["P0"](k, dtype=dtype, **kwargs) + ).toarray() - inv(G, True) def green(self, E, k=(0, 0, 0), dtype=None, **kwargs): - r""" Calculate the real-space Green function + r"""Calculate the real-space Green function The real space Green function is calculated via: @@ -1597,13 +1756,17 @@ def green(self, E, k=(0, 0, 0), dtype=None, **kwargs): k_ax = self._k_axes k = _a.asarrayd(k) - is_k = np.any(k != 0.) + is_k = np.any(k != 0.0) if is_k: axes = [self.semi.semi_inf] + k_ax.tolist() - if np.any(k[axes] != 0.): - raise ValueError(f"{self.__class__.__name__}.green requires k-point to be zero along the integrated axes.") + if np.any(k[axes] != 0.0): + raise ValueError( + f"{self.__class__.__name__}.green requires k-point to be zero along the integrated axes." + ) if trs: - raise ValueError(f"{self.__class__.__name__}.green requires a k-point sampled Green function to not use time reversal symmetry.") + raise ValueError( + f"{self.__class__.__name__}.green requires a k-point sampled Green function to not use time reversal symmetry." + ) # Shift k-points to get the correct k-point in the larger one. bz._k += k.reshape(1, 3) @@ -1617,32 +1780,48 @@ def green(self, E, k=(0, 0, 0), dtype=None, **kwargs): getri = linalg_info("getri", dtype) getri_lwork = linalg_info("getri_lwork", dtype) lwork = int(1.01 * _compute_lwork(getri_lwork, len(M0))) + def inv(A): lu, piv, info = getrf(A, overwrite_a=True) if info == 0: x, info = getri(lu, piv, lwork=lwork, overwrite_lu=True) if info != 0: - raise ValueError(f"{self.__class__.__name__}.green could not compute the inverse.") + raise ValueError( + f"{self.__class__.__name__}.green could not compute the inverse." + ) return x if M0.orthogonal: # Orthogonal *always* identity S0E = eye(len(M0), dtype=dtype) * E + def _calc_green(k, dtype, surf_orbs, semi_bulk): invG = S0E - M0Pk(k, dtype=dtype, format="array", **kwargs) if semi_bulk: - invG[surf_orbs, surf_orbs.T] = SE(E, k, dtype=dtype, bulk=semi_bulk, **kwargs) + invG[surf_orbs, surf_orbs.T] = SE( + E, k, dtype=dtype, bulk=semi_bulk, **kwargs + ) else: - invG[surf_orbs, surf_orbs.T] -= SE(E, k, dtype=dtype, bulk=semi_bulk, **kwargs) + invG[surf_orbs, surf_orbs.T] -= SE( + E, k, dtype=dtype, bulk=semi_bulk, **kwargs + ) return inv(invG) + else: M0Sk = M0.Sk + def _calc_green(k, dtype, surf_orbs, semi_bulk): - invG = M0Sk(k, dtype=dtype, format="array") * E - M0Pk(k, dtype=dtype, format="array", **kwargs) + invG = M0Sk(k, dtype=dtype, format="array") * E - M0Pk( + k, dtype=dtype, format="array", **kwargs + ) if semi_bulk: - invG[surf_orbs, surf_orbs.T] = SE(E, k, dtype=dtype, bulk=semi_bulk, **kwargs) + invG[surf_orbs, surf_orbs.T] = SE( + E, k, dtype=dtype, bulk=semi_bulk, **kwargs + ) else: - invG[surf_orbs, surf_orbs.T] -= SE(E, k, dtype=dtype, bulk=semi_bulk, **kwargs) + invG[surf_orbs, surf_orbs.T] -= SE( + E, k, dtype=dtype, bulk=semi_bulk, **kwargs + ) return inv(invG) # Create functions used to calculate the real-space Green function @@ -1656,15 +1835,23 @@ def _calc_green(k, dtype, surf_orbs, semi_bulk): # If using Bloch's theorem we need to wrap the Green function calculation # as the method call. if len(bloch) > 1: + def _func_bloch(k, dtype, surf_orbs, semi_bulk): - return bloch(_calc_green, k, dtype=dtype, surf_orbs=surf_orbs, semi_bulk=semi_bulk) + return bloch( + _calc_green, + k, + dtype=dtype, + surf_orbs=surf_orbs, + semi_bulk=semi_bulk, + ) + else: _func_bloch = _calc_green # calculate the Green function - G = bz.apply.average(_func_bloch)(dtype=dtype, - surf_orbs=self._surface_orbs, - semi_bulk=opt["semi_bulk"]) + G = bz.apply.average(_func_bloch)( + dtype=dtype, surf_orbs=self._surface_orbs, semi_bulk=opt["semi_bulk"] + ) if is_k: # Restore Brillouin zone k-points @@ -1676,5 +1863,5 @@ def _func_bloch(k, dtype, surf_orbs, semi_bulk): return G def clear(self): - """ Clears the internal arrays created in `initialize` """ + """Clears the internal arrays created in `initialize`""" del self._calc diff --git a/src/sisl/physics/sparse.py b/src/sisl/physics/sparse.py index 1cbc7ba329..92b214f1f8 100644 --- a/src/sisl/physics/sparse.py +++ b/src/sisl/physics/sparse.py @@ -27,7 +27,7 @@ @set_module("sisl.physics") class SparseOrbitalBZ(SparseOrbital): - r""" Sparse object containing the orbital connections in a Brillouin zone + r"""Sparse object containing the orbital connections in a Brillouin zone It contains an intrinsic sparse matrix of the physical elements. @@ -72,7 +72,7 @@ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): self._reset() def _reset(self): - r""" Reset object according to the options, please refer to `SparseOrbital.reset` for details """ + r"""Reset object according to the options, please refer to `SparseOrbital.reset` for details""" if self.orthogonal: self.Sk = self._Sk_diagonal self.S_idx = -100 @@ -93,20 +93,20 @@ def _cls_kwargs(self): @property def orthogonal(self): - r""" True if the object is using an orthogonal basis """ + r"""True if the object is using an orthogonal basis""" return self._orthogonal @property def non_orthogonal(self): - r""" True if the object is using a non-orthogonal basis """ + r"""True if the object is using a non-orthogonal basis""" return not self._orthogonal def __len__(self): - r""" Returns number of rows in the basis (if non-collinear or spin-orbit, twice the number of orbitals) """ + r"""Returns number of rows in the basis (if non-collinear or spin-orbit, twice the number of orbitals)""" return self.no def __str__(self): - r""" Representation of the model """ + r"""Representation of the model""" s = f"{self.__class__.__name__}{{dim: {self.dim}, non-zero: {self.nnz}, orthogonal: {self.orthogonal}\n " return s + str(self.geometry).replace("\n", "\n ") + "\n}" @@ -116,7 +116,7 @@ def __repr__(self): @property def S(self): - r""" Access the overlap elements associated with the sparse matrix """ + r"""Access the overlap elements associated with the sparse matrix""" if self.orthogonal: return None self._def_dim = self.S_idx @@ -124,7 +124,7 @@ def S(self): @classmethod def fromsp(cls, geometry, P, S=None, **kwargs): - r""" Create a sparse model from a preset `Geometry` and a list of sparse matrices + r"""Create a sparse model from a preset `Geometry` and a list of sparse matrices The passed sparse matrices are in one of `scipy.sparse` formats. @@ -163,13 +163,15 @@ def fromsp(cls, geometry, P, S=None, **kwargs): p._csr = p._csr.fromsp(*P, dtype=kwargs.get("dtype")) if p._size != P[0].shape[0]: - raise ValueError(f"{cls.__name__}.fromsp cannot create a new class, the geometry " - "and sparse matrices does not have coinciding dimensions size != P[0].shape[0]") + raise ValueError( + f"{cls.__name__}.fromsp cannot create a new class, the geometry " + "and sparse matrices does not have coinciding dimensions size != P[0].shape[0]" + ) return p def iter_orbitals(self, atoms=None, local=False): - r""" Iterations of the orbital space in the geometry, two indices from loop + r"""Iterations of the orbital space in the geometry, two indices from loop An iterator returning the current atomic index and the corresponding orbital index. @@ -202,7 +204,7 @@ def iter_orbitals(self, atoms=None, local=False): yield from self.geometry.iter_orbitals(atoms=atoms, local=local) def _Pk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", _dim=0): - r""" Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a polarized system + r"""Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a polarized system Parameters ---------- @@ -217,7 +219,7 @@ def _Pk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", _dim=0): return matrix_k(gauge, self, _dim, self.lattice, k, dtype, format) def _dPk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", _dim=0): - r""" Sparse matrix (``scipy.sparse.csr_matrix``) at `k` differentiated with respect to `k` for a polarized system + r"""Sparse matrix (``scipy.sparse.csr_matrix``) at `k` differentiated with respect to `k` for a polarized system Parameters ---------- @@ -232,7 +234,7 @@ def _dPk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", _dim=0): return matrix_dk(gauge, self, _dim, self.lattice, k, dtype, format) def _ddPk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", _dim=0): - r""" Sparse matrix (``scipy.sparse.csr_matrix``) at `k` double differentiated with respect to `k` for a polarized system + r"""Sparse matrix (``scipy.sparse.csr_matrix``) at `k` double differentiated with respect to `k` for a polarized system Parameters ---------- @@ -246,8 +248,10 @@ def _ddPk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", _dim=0): k = _a.asarrayd(k).ravel() return matrix_ddk(gauge, self, _dim, self.lattice, k, dtype, format) - def Sk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): # pylint: disable=E0202 - r""" Setup the overlap matrix for a given k-point + def Sk( + self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs + ): # pylint: disable=E0202 + r"""Setup the overlap matrix for a given k-point Creation and return of the overlap matrix for a given k-point (default to Gamma). @@ -299,8 +303,10 @@ def Sk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): """ pass - def _Sk_diagonal(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): - r""" For an orthogonal case we always return the identity matrix """ + def _Sk_diagonal( + self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs + ): + r"""For an orthogonal case we always return the identity matrix""" if dtype is None: dtype = np.float64 nr = len(self) @@ -317,14 +323,14 @@ def _Sk_diagonal(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, # TODO if format in ("array", "matrix", "dense"): S = np.zeros([nr, nc], dtype=dtype) - np.fill_diagonal(S, 1.) + np.fill_diagonal(S, 1.0) return S S = csr_matrix((nr, nc), dtype=dtype) - S.setdiag(1.) + S.setdiag(1.0) return S.asformat(format) def _Sk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k`. + r"""Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k`. Parameters ---------- @@ -338,7 +344,7 @@ def _Sk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return self._Pk(k, dtype=dtype, gauge=gauge, format=format, _dim=self.S_idx) def dSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): - r""" Setup the :math:`k`-derivatie of the overlap matrix for a given k-point + r"""Setup the :math:`k`-derivatie of the overlap matrix for a given k-point Creation and return of the derivative of the overlap matrix for a given k-point (default to Gamma). @@ -389,7 +395,7 @@ def dSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs) pass def _dSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` differentiated with respect to `k` + r"""Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` differentiated with respect to `k` Parameters ---------- @@ -403,7 +409,7 @@ def _dSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return self._dPk(k, dtype=dtype, gauge=gauge, format=format, _dim=self.S_idx) def _dSk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` for non-collinear spin, differentiated with respect to `k` + r"""Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` for non-collinear spin, differentiated with respect to `k` Parameters ---------- @@ -415,10 +421,14 @@ def _dSk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): chosen gauge """ k = _a.asarrayd(k).ravel() - return matrix_dk_nc_diag(gauge, self, self.S_idx, self.lattice, k, dtype, format) + return matrix_dk_nc_diag( + gauge, self, self.S_idx, self.lattice, k, dtype, format + ) - def ddSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs): # pylint: disable=E0202 - r""" Setup the double :math:`k`-derivatie of the overlap matrix for a given k-point + def ddSk( + self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs + ): # pylint: disable=E0202 + r"""Setup the double :math:`k`-derivatie of the overlap matrix for a given k-point Creation and return of the double derivative of the overlap matrix for a given k-point (default to Gamma). @@ -469,7 +479,7 @@ def ddSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr", *args, **kwargs pass def _ddSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` double differentiated with respect to `k` + r"""Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` double differentiated with respect to `k` Parameters ---------- @@ -483,7 +493,7 @@ def _ddSk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return self._ddPk(k, dtype=dtype, gauge=gauge, format=format, _dim=self.S_idx) def _ddSk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` for non-collinear spin, differentiated with respect to `k` + r"""Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k` for non-collinear spin, differentiated with respect to `k` Parameters ---------- @@ -495,10 +505,12 @@ def _ddSk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): chosen gauge """ k = _a.asarrayd(k).ravel() - return matrix_ddk_nc_diag(gauge, self, self.S_idx, self.lattice, k, dtype, format) + return matrix_ddk_nc_diag( + gauge, self, self.S_idx, self.lattice, k, dtype, format + ) def eig(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): - r""" Returns the eigenvalues of the physical quantity (using the non-Hermitian solver) + r"""Returns the eigenvalues of the physical quantity (using the non-Hermitian solver) Setup the system and overlap matrix with respect to the given k-point and calculate the eigenvalues. @@ -518,7 +530,7 @@ def eig(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): return lin.eig_destroy(P, S, **kwargs) def eigh(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): - r""" Returns the eigenvalues of the physical quantity + r"""Returns the eigenvalues of the physical quantity Setup the system and overlap matrix with respect to the given k-point and calculate the eigenvalues. @@ -534,7 +546,7 @@ def eigh(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): return lin.eigh_destroy(P, S, eigvals_only=eigvals_only, **kwargs) def eigsh(self, k=(0, 0, 0), n=10, gauge="R", eigvals_only=True, **kwargs): - r""" Calculates a subset of eigenvalues of the physical quantity (default 10) + r"""Calculates a subset of eigenvalues of the physical quantity (default 10) Setup the quantity and overlap matrix with respect to the given k-point and calculate a subset of the eigenvalues using the sparse algorithms. @@ -555,7 +567,7 @@ def eigsh(self, k=(0, 0, 0), n=10, gauge="R", eigvals_only=True, **kwargs): def __getstate__(self): return { "sparseorbitalbz": super().__getstate__(), - "orthogonal": self._orthogonal + "orthogonal": self._orthogonal, } def __setstate__(self, state): @@ -566,7 +578,7 @@ def __setstate__(self, state): @set_module("sisl.physics") class SparseOrbitalBZSpin(SparseOrbitalBZ): - r""" Sparse object containing the orbital connections in a Brillouin zone with possible spin-components + r"""Sparse object containing the orbital connections in a Brillouin zone with possible spin-components It contains an intrinsic sparse matrix of the physical elements. @@ -606,10 +618,12 @@ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): if isinstance(dim, Spin): spin = dim else: - spin = {1: Spin.UNPOLARIZED, - 2: Spin.POLARIZED, - 4: Spin.NONCOLINEAR, - 8: Spin.SPINORBIT}.get(dim) + spin = { + 1: Spin.UNPOLARIZED, + 2: Spin.POLARIZED, + 4: Spin.NONCOLINEAR, + 8: Spin.SPINORBIT, + }.get(dim) else: spin = kwargs.pop("spin") self._spin = Spin(spin, dtype) @@ -618,7 +632,7 @@ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): self._reset() def _reset(self): - r""" Reset object according to the options, please refer to `SparseOrbital.reset` for details """ + r"""Reset object according to the options, please refer to `SparseOrbital.reset` for details""" super()._reset() if self.spin.is_unpolarized: @@ -691,11 +705,11 @@ def _cls_kwargs(self): @property def spin(self): - r""" Associated spin class """ + r"""Associated spin class""" return self._spin def create_construct(self, R, param): - r""" Create a simple function for passing to the `construct` function. + r"""Create a simple function for passing to the `construct` function. This is to relieve the creation of simplistic functions needed for setting up sparse elements. @@ -737,44 +751,55 @@ def create_construct(self, R, param): construct : routine to create the sparse matrix from a generic function (as returned from `create_construct`) """ if len(R) != len(param): - raise ValueError(f"{self.__class__.__name__}.create_construct got different lengths of `R` and `param`") + raise ValueError( + f"{self.__class__.__name__}.create_construct got different lengths of `R` and `param`" + ) if not self.spin.is_diagonal: is_complex = self.dkind == "c" if self.spin.is_spinorbit: if is_complex: nv = 4 # Hermitian parameters - paramH = [[p[0].conj(), p[1].conj(), p[3].conj(), p[2].conj(), *p[4:]] - for p in param] + paramH = [ + [p[0].conj(), p[1].conj(), p[3].conj(), p[2].conj(), *p[4:]] + for p in param + ] else: nv = 8 # Hermitian parameters - paramH = [[p[0], p[1], p[6], -p[7], -p[4], -p[5], p[2], -p[3], *p[8:]] - for p in param] + paramH = [ + [p[0], p[1], p[6], -p[7], -p[4], -p[5], p[2], -p[3], *p[8:]] + for p in param + ] if not self.orthogonal: nv += 1 # ensure we have correct number of values assert all(len(p) == nv for p in param) - if R[0] <= 0.1001: # no atom closer than 0.1001 Ang! + if R[0] <= 0.1001: # no atom closer than 0.1001 Ang! # We check that the the parameters here is Hermitian p = param[0] if is_complex: - onsite = np.array([[p[0], p[2]], - [p[3], p[1]]], self.dtype) + onsite = np.array([[p[0], p[2]], [p[3], p[1]]], self.dtype) else: - onsite = np.array([[p[0] + 1j * p[4], p[2] + 1j * p[3]], - [p[6] + 1j * p[7], p[1] + 1j * p[5]]], np.complex128) + onsite = np.array( + [ + [p[0] + 1j * p[4], p[2] + 1j * p[3]], + [p[6] + 1j * p[7], p[1] + 1j * p[5]], + ], + np.complex128, + ) if not np.allclose(onsite, onsite.T.conj()): - warn(f"{self.__class__.__name__}.create_construct is NOT Hermitian for on-site terms. This is your responsibility!") + warn( + f"{self.__class__.__name__}.create_construct is NOT Hermitian for on-site terms. This is your responsibility!" + ) elif self.spin.is_noncolinear: if is_complex: nv = 3 # Hermitian parameters - paramH = [[p[0].conj(), p[1].conj(), p[2], *p[3:]] - for p in param] + paramH = [[p[0].conj(), p[1].conj(), p[2], *p[3:]] for p in param] else: nv = 4 # Hermitian parameters @@ -809,14 +834,17 @@ def func(self, ia, atoms, atoms_xyz=None): return super().create_construct(R, param) def __len__(self): - r""" Returns number of rows in the basis (if non-collinear or spin-orbit, twice the number of orbitals) """ + r"""Returns number of rows in the basis (if non-collinear or spin-orbit, twice the number of orbitals)""" if self.spin.is_diagonal: return self.no return self.no * 2 def __str__(self): - r""" Representation of the model """ - s = self.__class__.__name__ + f"{{non-zero: {self.nnz}, orthogonal: {self.orthogonal},\n " + r"""Representation of the model""" + s = ( + self.__class__.__name__ + + f"{{non-zero: {self.nnz}, orthogonal: {self.orthogonal},\n " + ) s += str(self.spin).replace("\n", "\n ") + ",\n " s += str(self.geometry).replace("\n", "\n ") return s + "\n}" @@ -827,12 +855,12 @@ def __repr__(self): Spin.UNPOLARIZED: "unpolarized", Spin.POLARIZED: "polarized", Spin.NONCOLINEAR: "noncolinear", - Spin.SPINORBIT: "spinorbit" - }.get(self.spin._kind, f"unkown({self.spin._kind})") + Spin.SPINORBIT: "spinorbit", + }.get(self.spin._kind, f"unkown({self.spin._kind})") return f"<{self.__module__}.{self.__class__.__name__} na={g.na}, no={g.no}, nsc={g.nsc}, dim={self.dim}, nnz={self.nnz}, spin={spin}>" def _Pk_unpolarized(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Sparse matrix (``scipy.sparse.csr_matrix``) at `k` + r"""Sparse matrix (``scipy.sparse.csr_matrix``) at `k` Parameters ---------- @@ -846,7 +874,7 @@ def _Pk_unpolarized(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return self._Pk(k, dtype=dtype, gauge=gauge, format=format) def _Pk_polarized(self, k=(0, 0, 0), spin=0, dtype=None, gauge="R", format="csr"): - r""" Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a polarized system + r"""Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a polarized system Parameters ---------- @@ -862,7 +890,7 @@ def _Pk_polarized(self, k=(0, 0, 0), spin=0, dtype=None, gauge="R", format="csr" return self._Pk(k, dtype=dtype, gauge=gauge, format=format, _dim=spin) def _Pk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system + r"""Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system Parameters ---------- @@ -877,7 +905,7 @@ def _Pk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return matrix_k_nc(gauge, self, self.lattice, k, dtype, format) def _Pk_spin_orbit(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a spin-orbit system + r"""Sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a spin-orbit system Parameters ---------- @@ -892,7 +920,7 @@ def _Pk_spin_orbit(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return matrix_k_so(gauge, self, self.lattice, k, dtype, format) def _dPk_unpolarized(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k`, differentiated with respect to `k` + r"""Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k`, differentiated with respect to `k` Parameters ---------- @@ -906,7 +934,7 @@ def _dPk_unpolarized(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return self._dPk(k, dtype=dtype, gauge=gauge, format=format) def _dPk_polarized(self, k=(0, 0, 0), spin=0, dtype=None, gauge="R", format="csr"): - r""" Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k`, differentiated with respect to `k` + r"""Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k`, differentiated with respect to `k` Parameters ---------- @@ -922,7 +950,7 @@ def _dPk_polarized(self, k=(0, 0, 0), spin=0, dtype=None, gauge="R", format="csr return self._dPk(k, dtype=dtype, gauge=gauge, format=format, _dim=spin) def _dPk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` + r"""Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` Parameters ---------- @@ -937,7 +965,7 @@ def _dPk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return matrix_dk_nc(gauge, self, self.lattice, k, dtype, format) def _dPk_spin_orbit(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` + r"""Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` Parameters ---------- @@ -952,7 +980,7 @@ def _dPk_spin_orbit(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return matrix_dk_so(gauge, self, self.lattice, k, dtype, format) def _ddPk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` twice + r"""Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` twice Parameters ---------- @@ -967,7 +995,7 @@ def _ddPk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return matrix_ddk_nc(gauge, self, self.lattice, k, dtype, format) def _ddPk_spin_orbit(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` + r"""Tuple of sparse matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system, differentiated with respect to `k` Parameters ---------- @@ -982,7 +1010,7 @@ def _ddPk_spin_orbit(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return matrix_ddk_so(gauge, self, self.lattice, k, dtype, format) def _Sk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k`. + r"""Overlap matrix in a ``scipy.sparse.csr_matrix`` at `k`. Parameters ---------- @@ -996,7 +1024,7 @@ def _Sk(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return self._Pk(k, dtype=dtype, gauge=gauge, format=format, _dim=self.S_idx) def _Sk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system + r"""Overlap matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system Parameters ---------- @@ -1011,7 +1039,7 @@ def _Sk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): return matrix_k_nc_diag(gauge, self, self.S_idx, self.lattice, k, dtype, format) def _dSk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): - r""" Overlap matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system + r"""Overlap matrix (``scipy.sparse.csr_matrix``) at `k` for a non-collinear system Parameters ---------- @@ -1023,10 +1051,12 @@ def _dSk_non_colinear(self, k=(0, 0, 0), dtype=None, gauge="R", format="csr"): chosen gauge """ k = _a.asarrayd(k).ravel() - return matrix_dk_nc_diag(gauge, self, self.S_idx, self.lattice, k, dtype, format) + return matrix_dk_nc_diag( + gauge, self, self.S_idx, self.lattice, k, dtype, format + ) def eig(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): - r""" Returns the eigenvalues of the physical quantity (using the non-Hermitian solver) + r"""Returns the eigenvalues of the physical quantity (using the non-Hermitian solver) Setup the system and overlap matrix with respect to the given k-point and calculate the eigenvalues. @@ -1058,7 +1088,7 @@ def eig(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): return lin.eig_destroy(P, S, **kwargs) def eigh(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): - r""" Returns the eigenvalues of the physical quantity + r"""Returns the eigenvalues of the physical quantity Setup the system and overlap matrix with respect to the given k-point and calculate the eigenvalues. @@ -1086,7 +1116,7 @@ def eigh(self, k=(0, 0, 0), gauge="R", eigvals_only=True, **kwargs): return lin.eigh_destroy(P, S, eigvals_only=eigvals_only, **kwargs) def eigsh(self, k=(0, 0, 0), n=10, gauge="R", eigvals_only=True, **kwargs): - r""" Calculates a subset of eigenvalues of the physical quantity (default 10) + r"""Calculates a subset of eigenvalues of the physical quantity (default 10) Setup the quantity and overlap matrix with respect to the given k-point and calculate a subset of the eigenvalues using the sparse algorithms. @@ -1114,7 +1144,7 @@ def eigsh(self, k=(0, 0, 0), n=10, gauge="R", eigvals_only=True, **kwargs): return lin.eigsh(P, M=S, k=n, return_eigenvectors=not eigvals_only, **kwargs) def transpose(self, hermitian=False, spin=True, sort=True): - r""" A transpose copy of this object, possibly apply the Hermitian conjugate as well + r"""A transpose copy of this object, possibly apply the Hermitian conjugate as well Parameters ---------- @@ -1148,7 +1178,7 @@ def transpose(self, hermitian=False, spin=True, sort=True): if sp.dkind == "f": # imaginary components # 12,11,22,21 - D[:, [3, 4, 5, 7]] *= -1. + D[:, [3, 4, 5, 7]] *= -1.0 else: D[:, :] = np.conj(D[:, :]) elif spin: @@ -1160,7 +1190,7 @@ def transpose(self, hermitian=False, spin=True, sort=True): elif sp.is_noncolinear: if hermitian and spin: - pass # do nothing, it is already ensured Hermitian + pass # do nothing, it is already ensured Hermitian elif hermitian or spin: # conjugate the imaginary value # since for transposing D[:, 3] is the same @@ -1178,7 +1208,7 @@ def transpose(self, hermitian=False, spin=True, sort=True): return new def trs(self): - r""" Create a new matrix with applied time-reversal-symmetry + r"""Create a new matrix with applied time-reversal-symmetry Time reversal symmetry is applied using the following equality: @@ -1214,7 +1244,7 @@ def trs(self): return new def transform(self, matrix=None, dtype=None, spin=None, orthogonal=None): - r""" Transform the matrix by either a matrix or new spin configuration + r"""Transform the matrix by either a matrix or new spin configuration 1. General transformation: * If `matrix` is provided, a linear transformation :math:`R^n \rightarrow R^m` is applied @@ -1291,41 +1321,44 @@ def transform(self, matrix=None, dtype=None, spin=None, orthogonal=None): matrix[:m, :n] = np.eye(m, n, dtype=dtype) if not self.orthogonal and not orthogonal: # ensure the overlap matrix is carried over - matrix[-1, -1] = 1. + matrix[-1, -1] = 1.0 if spin.is_unpolarized and self.spin.size > 1: # average up and down components matrix[0, [0, 1]] = 0.5 elif spin.size > 1 and self.spin.is_unpolarized: # set up and down components to unpolarized value - matrix[[0, 1], 0] = 1. + matrix[[0, 1], 0] = 1.0 else: # convert to numpy array matrix = np.asarray(matrix) - if (M != m and matrix.shape[0] == m and - N != n and matrix.shape[1] == n): + if M != m and matrix.shape[0] == m and N != n and matrix.shape[1] == n: # this means that the user wants to preserve the overlap matrix_full = np.zeros([M, N], dtype=dtype) matrix_full[:m, :n] = matrix - matrix_full[-1, -1] = 1. + matrix_full[-1, -1] = 1.0 matrix = matrix_full if matrix.shape[0] != M or matrix.shape[1] != N: # while this check also occurs in the SparseCSR.transform # code, but the error message is better placed here. - raise ValueError(f"{self.__class__.__name__}.transform incompatible " - f"transformation matrix and spin dimensions: " - f"matrix.shape={matrix.shape} and self.spin={N} ; out.spin={M}") - - new = self.__class__(self.geometry.copy(), spin=spin, dtype=dtype, nnzpr=1, orthogonal=orthogonal) + raise ValueError( + f"{self.__class__.__name__}.transform incompatible " + f"transformation matrix and spin dimensions: " + f"matrix.shape={matrix.shape} and self.spin={N} ; out.spin={M}" + ) + + new = self.__class__( + self.geometry.copy(), spin=spin, dtype=dtype, nnzpr=1, orthogonal=orthogonal + ) new._csr = self._csr.transform(matrix, dtype=dtype) if not orthogonal and self.orthogonal: # set identity overlap matrix, loop over rows for i in range(new._csr.shape[0]): - new._csr[i, i, -1] = 1. + new._csr[i, i, -1] = 1.0 return new diff --git a/src/sisl/physics/spin.py b/src/sisl/physics/spin.py index cb40d20e4d..fec607e1ff 100644 --- a/src/sisl/physics/spin.py +++ b/src/sisl/physics/spin.py @@ -5,12 +5,12 @@ from sisl._internal import set_module -__all__ = ['Spin'] +__all__ = ["Spin"] @set_module("sisl.physics") class Spin: - r""" Spin class to determine configurations and spin components. + r"""Spin class to determine configurations and spin components. The basic class `Spin` implements a generic method to determine a spin configuration. @@ -57,7 +57,6 @@ class Spin: __slots__ = ("_size", "_kind", "_dtype") def __init__(self, kind="", dtype=None): - if isinstance(kind, Spin): if dtype is None: dtype = kind._dtype @@ -75,37 +74,50 @@ def __init__(self, kind="", dtype=None): if isinstance(kind, str): kind = kind.lower() - kind = {"unpolarized": Spin.UNPOLARIZED, "": Spin.UNPOLARIZED, - Spin.UNPOLARIZED: Spin.UNPOLARIZED, - "polarized": Spin.POLARIZED, "p": Spin.POLARIZED, - "pol": Spin.POLARIZED, - Spin.POLARIZED: Spin.POLARIZED, - "noncolinear": Spin.NONCOLINEAR, - "noncollinear": Spin.NONCOLINEAR, - "non-colinear": Spin.NONCOLINEAR, - "non-collinear": Spin.NONCOLINEAR, "nc": Spin.NONCOLINEAR, - Spin.NONCOLINEAR: Spin.NONCOLINEAR, - "spinorbit": Spin.SPINORBIT, - "spin-orbit": Spin.SPINORBIT, "so": Spin.SPINORBIT, - "soc": Spin.SPINORBIT, Spin.SPINORBIT: Spin.SPINORBIT}.get(kind) + kind = { + "unpolarized": Spin.UNPOLARIZED, + "": Spin.UNPOLARIZED, + Spin.UNPOLARIZED: Spin.UNPOLARIZED, + "polarized": Spin.POLARIZED, + "p": Spin.POLARIZED, + "pol": Spin.POLARIZED, + Spin.POLARIZED: Spin.POLARIZED, + "noncolinear": Spin.NONCOLINEAR, + "noncollinear": Spin.NONCOLINEAR, + "non-colinear": Spin.NONCOLINEAR, + "non-collinear": Spin.NONCOLINEAR, + "nc": Spin.NONCOLINEAR, + Spin.NONCOLINEAR: Spin.NONCOLINEAR, + "spinorbit": Spin.SPINORBIT, + "spin-orbit": Spin.SPINORBIT, + "so": Spin.SPINORBIT, + "soc": Spin.SPINORBIT, + Spin.SPINORBIT: Spin.SPINORBIT, + }.get(kind) if kind is None: - raise ValueError(f"{self.__class__.__name__} initialization went wrong because of wrong " - "kind specification. Could not determine the kind of spin!") + raise ValueError( + f"{self.__class__.__name__} initialization went wrong because of wrong " + "kind specification. Could not determine the kind of spin!" + ) # Now assert the checks self._kind = kind if np.dtype(dtype).kind == "c": - size = {self.UNPOLARIZED: 1, - self.POLARIZED: 2, - self.NONCOLINEAR: 4, - self.SPINORBIT: 4}.get(kind) + size = { + self.UNPOLARIZED: 1, + self.POLARIZED: 2, + self.NONCOLINEAR: 4, + self.SPINORBIT: 4, + }.get(kind) else: - size = {self.UNPOLARIZED: 1, - self.POLARIZED: 2, - self.NONCOLINEAR: 4, - self.SPINORBIT: 8}.get(kind) + size = { + self.UNPOLARIZED: 1, + self.POLARIZED: 2, + self.NONCOLINEAR: 4, + self.SPINORBIT: 8, + }.get(kind) self._size = size @@ -119,55 +131,55 @@ def __str__(self): return f"{self.__class__.__name__}{{spin-orbit, kind={self.dkind}}}" def copy(self): - """ Create a copy of the spin-object """ + """Create a copy of the spin-object""" return Spin(self.kind, self.dtype) @property def dtype(self): - """ Data-type of the spin configuration """ + """Data-type of the spin configuration""" return self._dtype @property def dkind(self): - """ Data-type kind """ + """Data-type kind""" return np.dtype(self._dtype).kind @property def size(self): - """ Number of elements to describe the spin-components """ + """Number of elements to describe the spin-components""" return self._size @property def spinor(self): - """ Number of spinor components (1 or 2) """ + """Number of spinor components (1 or 2)""" return min(2, self._size) @property def kind(self): - """ A unique ID for the kind of spin configuration """ + """A unique ID for the kind of spin configuration""" return self._kind @property def is_unpolarized(self): - """ True if the configuration is not polarized """ + """True if the configuration is not polarized""" # Regardless of data-type return self.kind == Spin.UNPOLARIZED @property def is_polarized(self): - """ True if the configuration is polarized """ + """True if the configuration is polarized""" return self.kind == Spin.POLARIZED is_colinear = is_polarized @property def is_noncolinear(self): - """ True if the configuration non-collinear """ + """True if the configuration non-collinear""" return self.kind == Spin.NONCOLINEAR @property def is_diagonal(self): - """ Whether the spin-box is only using the diagonal components + """Whether the spin-box is only using the diagonal components This will return true for non-polarized and polarized spin configurations. Otherwise false. @@ -176,7 +188,7 @@ def is_diagonal(self): @property def is_spinorbit(self): - """ True if the configuration is spin-orbit """ + """True if the configuration is spin-orbit""" return self.kind == Spin.SPINORBIT def __len__(self): @@ -202,11 +214,7 @@ def __ge__(self, other): return self.kind >= other.kind def __getstate__(self): - return { - "size": self.size, - "kind": self.kind, - "dtype": self.dtype - } + return {"size": self.size, "kind": self.kind, "dtype": self.dtype} def __setstate__(self, state): self._size = state["size"] diff --git a/src/sisl/physics/state.py b/src/sisl/physics/state.py index 12f6a8c087..3259454040 100644 --- a/src/sisl/physics/state.py +++ b/src/sisl/physics/state.py @@ -13,7 +13,7 @@ from sisl.linalg import eigh_destroy from sisl.messages import warn -__all__ = ['degenerate_decouple', 'Coefficient', 'State', 'StateC'] +__all__ = ["degenerate_decouple", "Coefficient", "State", "StateC"] _pi = np.pi _pi2 = np.pi * 2 @@ -56,8 +56,9 @@ def degenerate_decouple(state, M): class _FakeMatrix: - """ Replacement object which superseedes a matrix """ - __slots__ = ('n', 'm') + """Replacement object which superseedes a matrix""" + + __slots__ = ("n", "m") ndim = 2 def __init__(self, n, m=None): @@ -92,8 +93,9 @@ def T(self): @set_module("sisl.physics") class ParentContainer: - """ A container for parent and information """ - __slots__ = ['parent', 'info'] + """A container for parent and information""" + + __slots__ = ["parent", "info"] def __init__(self, parent, **info): self.parent = parent @@ -107,7 +109,7 @@ def __init__(self, parent, **info): @singledispatchmethod def _sanitize_index(self, idx): - r""" Ensure indices are transferred to acceptable integers """ + r"""Ensure indices are transferred to acceptable integers""" if idx is None: # in case __len__ is not defined, this will fail... return np.arange(len(self)) @@ -127,7 +129,7 @@ def _(self, idx): @set_module("sisl.physics") class Coefficient(ParentContainer): - """ An object holding coefficients for a parent with info + """An object holding coefficients for a parent with info Parameters ---------- @@ -139,7 +141,8 @@ class Coefficient(ParentContainer): an info dictionary that turns into an attribute on the object. This `info` may contain anything that may be relevant for the coefficient. """ - __slots__ = ['c'] + + __slots__ = ["c"] def __init__(self, c, parent=None, **info): super().__init__(parent, **info) @@ -149,41 +152,41 @@ def __init__(self, c, parent=None, **info): """ def __str__(self): - """ The string representation of this object """ + """The string representation of this object""" s = f"{self.__class__.__name__}{{coefficients: {len(self)}, kind: {self.dkind}" if self.parent is None: - s += '}}' + s += "}}" else: - s += ',\n {}\n}}'.format(str(self.parent).replace('\n', '\n ')) + s += ",\n {}\n}}".format(str(self.parent).replace("\n", "\n ")) return s def __len__(self): - """ Number of coefficients """ + """Number of coefficients""" return self.shape[0] @property def dtype(self): - """ Data-type for the coefficients """ + """Data-type for the coefficients""" return self.c.dtype @property def dkind(self): - """ The data-type of the coefficient (in str) """ + """The data-type of the coefficient (in str)""" return np.dtype(self.c.dtype).kind @property def shape(self): - """ Returns the shape of the coefficients """ + """Returns the shape of the coefficients""" return self.c.shape def copy(self): - """ Return a copy (only the coefficients are copied). ``parent`` and ``info`` are passed by reference """ + """Return a copy (only the coefficients are copied). ``parent`` and ``info`` are passed by reference""" copy = self.__class__(self.c.copy(), self.parent) copy.info = self.info return copy def degenerate(self, eps=1e-8): - """ Find degenerate coefficients with a specified precision + """Find degenerate coefficients with a specified precision Parameters ---------- @@ -213,7 +216,7 @@ def degenerate(self, eps=1e-8): return deg def sub(self, idx, inplace=False): - """ Return a new coefficient with only the specified coefficients + """Return a new coefficient with only the specified coefficients Parameters ---------- @@ -236,7 +239,7 @@ def sub(self, idx, inplace=False): return sub def remove(self, idx, inplace=False): - """ Return a new coefficient without the specified coefficients + """Return a new coefficient without the specified coefficients Parameters ---------- @@ -254,7 +257,7 @@ def remove(self, idx, inplace=False): return self.sub(idx, inplace) def __getitem__(self, key): - """ Return a new coefficient object with only one associated coefficient + """Return a new coefficient object with only one associated coefficient Parameters ---------- @@ -269,7 +272,7 @@ def __getitem__(self, key): return self.sub(key) def iter(self, asarray=False): - """ An iterator looping over the coefficients in this system + """An iterator looping over the coefficients in this system Parameters ---------- @@ -296,7 +299,7 @@ def iter(self, asarray=False): @set_module("sisl.physics") class State(ParentContainer): - """ An object handling a set of vectors describing a given *state* + """An object handling a set of vectors describing a given *state* Parameters ---------- @@ -312,10 +315,11 @@ class State(ParentContainer): ----- This class should be subclassed! """ - __slots__ = ['state'] + + __slots__ = ["state"] def __init__(self, state, parent=None, **info): - """ Define a state container with a given set of states """ + """Define a state container with a given set of states""" super().__init__(parent, **info) self.state = np.atleast_2d(state) """ numpy.ndarray @@ -323,41 +327,41 @@ def __init__(self, state, parent=None, **info): """ def __str__(self): - """ The string representation of this object """ + """The string representation of this object""" s = f"{self.__class__.__name__}{{states: {len(self)}, kind: {self.dkind}" if self.parent is None: - s += '}' + s += "}" else: - s += ',\n {}\n}}'.format(str(self.parent).replace('\n', '\n ')) + s += ",\n {}\n}}".format(str(self.parent).replace("\n", "\n ")) return s def __len__(self): - """ Number of states """ + """Number of states""" return self.shape[0] @property def dtype(self): - """ Data-type for the state """ + """Data-type for the state""" return self.state.dtype @property def dkind(self): - """ The data-type of the state (in str) """ + """The data-type of the state (in str)""" return np.dtype(self.state.dtype).kind @property def shape(self): - """ Returns the shape of the state """ + """Returns the shape of the state""" return self.state.shape def copy(self): - """ Return a copy (only the state is copied). ``parent`` and ``info`` are passed by reference """ + """Return a copy (only the state is copied). ``parent`` and ``info`` are passed by reference""" copy = self.__class__(self.state.copy(), self.parent) copy.info = self.info return copy def sub(self, idx, inplace=False): - """ Return a new state with only the specified states + """Return a new state with only the specified states Parameters ---------- @@ -380,7 +384,7 @@ def sub(self, idx, inplace=False): return sub def remove(self, idx, inplace=False): - """ Return a new state without the specified vectors + """Return a new state without the specified vectors Parameters ---------- @@ -398,7 +402,7 @@ def remove(self, idx, inplace=False): return self.sub(idx, inplace) def translate(self, isc): - r""" Translate the vectors to a new unit-cell position + r"""Translate the vectors to a new unit-cell position The method is thoroughly explained in `tile` while this one only selects the corresponding state vector @@ -413,7 +417,7 @@ def translate(self, isc): tile : equivalent method for generating more cells simultaneously """ # the k-point gets reduced - k = _a.asarrayd(self.info.get("k", [0]*3)) + k = _a.asarrayd(self.info.get("k", [0] * 3)) assert len(isc) == 3 s = self.copy() @@ -423,7 +427,7 @@ def translate(self, isc): # i * a_0 + j * a_1 + k * a_2 if not np.allclose(k, 0): # there will only be a phase if k != 0 - s.state *= exp(2j*_pi * k @ isc) + s.state *= exp(2j * _pi * k @ isc) return s def tile(self, reps, axis, normalize=False, offset=0): @@ -459,7 +463,7 @@ def tile(self, reps, axis, normalize=False, offset=0): # the parent gets tiled parent = self.parent.tile(reps, axis) # the k-point gets reduced - k = _a.asarrayd(self.info.get("k", [0]*3)) + k = _a.asarrayd(self.info.get("k", [0] * 3)) # now tile the state vectors state = np.tile(self.state, (1, reps)).astype(np.complex128, copy=False) @@ -472,14 +476,14 @@ def tile(self, reps, axis, normalize=False, offset=0): # with T being # i * a_0 + j * a_1 + k * a_2 # We can leave out the lattice vectors entirely - phase = exp(2j*_pi * k[axis] * (_a.aranged(reps) + offset)) + phase = exp(2j * _pi * k[axis] * (_a.aranged(reps) + offset)) state *= phase.reshape(1, -1, 1) state.shape = (len(self), -1) # update new k; when we double the system, we halve the periodicity # and hence we need to account for this - k[axis] = (k[axis] * reps % 1) + k[axis] = k[axis] * reps % 1 while k[axis] > 0.5: k[axis] -= 1 while k[axis] <= -0.5: @@ -491,14 +495,14 @@ def tile(self, reps, axis, normalize=False, offset=0): s.state = state # update the k-point s.info = dict(**self.info) - s.info.update({'k': k}) + s.info.update({"k": k}) if normalize: return s.normalize() return s def __getitem__(self, key): - """ Return a new state with only one associated state + """Return a new state with only one associated state Parameters ---------- @@ -513,7 +517,7 @@ def __getitem__(self, key): return self.sub(key) def iter(self, asarray=False): - """ An iterator looping over the states in this system + """An iterator looping over the states in this system Parameters ---------- @@ -538,7 +542,7 @@ def iter(self, asarray=False): __iter__ = iter def norm(self): - r""" Return a vector with the Euclidean norm of each state :math:`\sqrt{\langle\psi|\psi\rangle}` + r"""Return a vector with the Euclidean norm of each state :math:`\sqrt{\langle\psi|\psi\rangle}` Returns ------- @@ -548,7 +552,7 @@ def norm(self): return self.norm2() ** 0.5 def norm2(self, sum=True): - r""" Return a vector with the norm of each state :math:`\langle\psi|\psi\rangle` + r"""Return a vector with the norm of each state :math:`\langle\psi|\psi\rangle` Parameters ---------- @@ -608,10 +612,10 @@ def ipr(self, q=2): state_abs2 = self.norm2(sum=False).real assert q >= 2, f"{self.__class__.__name__}.ipr requires q>=2" # abs2 is already having the exponent 2 - return (state_abs2 ** q).sum(-1) / state_abs2.sum(-1) ** q + return (state_abs2**q).sum(-1) / state_abs2.sum(-1) ** q def normalize(self): - r""" Return a normalized state where each state has :math:`|\psi|^2=1` + r"""Return a normalized state where each state has :math:`|\psi|^2=1` This is roughly equivalent to: @@ -634,7 +638,7 @@ def normalize(self): return s def outer(self, ket=None, matrix=None): - r""" Return the outer product by :math:`\sum_i|\psi_i\rangle\langle\psi'_i|` + r"""Return the outer product by :math:`\sum_i|\psi_i\rangle\langle\psi'_i|` Parameters ---------- @@ -660,7 +664,9 @@ def outer(self, ket=None, matrix=None): M = matrix ndim = M.ndim if ndim not in (0, 1, 2): - raise ValueError(f"{self.__class__.__name__}.outer only accepts matrices up to 2 dimensions.") + raise ValueError( + f"{self.__class__.__name__}.outer only accepts matrices up to 2 dimensions." + ) bra = self.state # decide on the ket @@ -677,29 +683,33 @@ def outer(self, ket=None, matrix=None): # ket M bra if ndim == 0: # M,N @ N, L - if (ket.shape[1] != bra.shape[1]): - raise ValueError(f"{self.__class__.__name__}.outer requires the objects to have matching shapes bra @ ket bra={self.shape}, ket={ket.shape[::-1]}") + if ket.shape[1] != bra.shape[1]: + raise ValueError( + f"{self.__class__.__name__}.outer requires the objects to have matching shapes bra @ ket bra={self.shape}, ket={ket.shape[::-1]}" + ) elif ndim == 1: # M,N @ N @ N, L - if (ket.shape[0] != M.shape[0] or - M.shape[0] != bra.shape[0]): - raise ValueError(f"{self.__class__.__name__}.outer requires the objects to have matching shapes ket @ M @ bra ket={ket.shape[::-1]}, M={M.shape}, bra={self.shape}") + if ket.shape[0] != M.shape[0] or M.shape[0] != bra.shape[0]: + raise ValueError( + f"{self.__class__.__name__}.outer requires the objects to have matching shapes ket @ M @ bra ket={ket.shape[::-1]}, M={M.shape}, bra={self.shape}" + ) elif ndim == 2: # M,N @ N,K @ K,L - if (ket.shape[0] != M.shape[0] or - M.shape[1] != bra.shape[0]): - raise ValueError(f"{self.__class__.__name__}.outer requires the objects to have matching shapes ket @ M @ bra ket={ket.shape[::-1]}, M={M.shape}, bra={self.shape}") + if ket.shape[0] != M.shape[0] or M.shape[1] != bra.shape[0]: + raise ValueError( + f"{self.__class__.__name__}.outer requires the objects to have matching shapes ket @ M @ bra ket={ket.shape[::-1]}, M={M.shape}, bra={self.shape}" + ) if ndim == 2: Aij = ket.T @ M.dot(np.conj(bra)) elif ndim == 1: - Aij = einsum('ij,i,ik->jk', ket, M, np.conj(bra)) + Aij = einsum("ij,i,ik->jk", ket, M, np.conj(bra)) elif ndim == 0: - Aij = einsum('ij,ik->jk', ket * M, np.conj(bra)) + Aij = einsum("ij,ik->jk", ket * M, np.conj(bra)) return Aij def inner(self, ket=None, matrix=None, diag=True): - r""" Calculate the inner product as :math:`\mathbf A_{ij} = \langle\psi_i|\mathbf M|\psi'_j\rangle` + r"""Calculate the inner product as :math:`\mathbf A_{ij} = \langle\psi_i|\mathbf M|\psi'_j\rangle` Parameters ---------- @@ -732,7 +742,9 @@ def inner(self, ket=None, matrix=None, diag=True): M = matrix ndim = M.ndim if ndim not in (0, 1, 2): - raise ValueError(f"{self.__class__.__name__}.inner only accepts matrices up to 2 dimensions.") + raise ValueError( + f"{self.__class__.__name__}.inner only accepts matrices up to 2 dimensions." + ) bra = self.state # decide on the ket @@ -749,38 +761,44 @@ def inner(self, ket=None, matrix=None, diag=True): # bra M ket if ndim == 0: # M,N @ N, L - if (bra.shape[1] != ket.shape[1]): - raise ValueError(f"{self.__class__.__name__}.inner requires the objects to have matching shapes bra @ ket bra={self.shape}, ket={ket.shape[::-1]}") + if bra.shape[1] != ket.shape[1]: + raise ValueError( + f"{self.__class__.__name__}.inner requires the objects to have matching shapes bra @ ket bra={self.shape}, ket={ket.shape[::-1]}" + ) elif ndim == 1: # M,N @ N @ N, L - if (bra.shape[1] != M.shape[0] or - M.shape[0] != ket.shape[1]): - raise ValueError(f"{self.__class__.__name__}.inner requires the objects to have matching shapes bra @ M @ ket bra={self.shape}, M={M.shape}, ket={ket.shape[::-1]}") + if bra.shape[1] != M.shape[0] or M.shape[0] != ket.shape[1]: + raise ValueError( + f"{self.__class__.__name__}.inner requires the objects to have matching shapes bra @ M @ ket bra={self.shape}, M={M.shape}, ket={ket.shape[::-1]}" + ) elif ndim == 2: # M,N @ N,K @ K,L - if (bra.shape[1] != M.shape[0] or - M.shape[1] != ket.shape[1]): - raise ValueError(f"{self.__class__.__name__}.inner requires the objects to have matching shapes bra @ M @ ket bra={self.shape}, M={M.shape}, ket={ket.shape[::-1]}") + if bra.shape[1] != M.shape[0] or M.shape[1] != ket.shape[1]: + raise ValueError( + f"{self.__class__.__name__}.inner requires the objects to have matching shapes bra @ M @ ket bra={self.shape}, M={M.shape}, ket={ket.shape[::-1]}" + ) if diag: if bra.shape[0] != ket.shape[0]: - raise ValueError(f"{self.__class__.__name__}.inner diagonal matrix product is non-square, please use diag=False or reduce number of vectors.") + raise ValueError( + f"{self.__class__.__name__}.inner diagonal matrix product is non-square, please use diag=False or reduce number of vectors." + ) if ndim == 2: - Aij = einsum('ij,ji->i', np.conj(bra), M.dot(ket.T)) + Aij = einsum("ij,ji->i", np.conj(bra), M.dot(ket.T)) elif ndim == 1: - Aij = einsum('ij,j,ij->i', np.conj(bra), M, ket) + Aij = einsum("ij,j,ij->i", np.conj(bra), M, ket) elif ndim == 0: - Aij = einsum('ij,ij->i', np.conj(bra), ket) * M + Aij = einsum("ij,ij->i", np.conj(bra), ket) * M elif ndim == 2: Aij = np.conj(bra) @ M.dot(ket.T) elif ndim == 1: - Aij = einsum('ij,j,kj->ik', np.conj(bra), M, ket) + Aij = einsum("ij,j,kj->ik", np.conj(bra), M, ket) elif ndim == 0: - Aij = einsum('ij,kj->ik', np.conj(bra), ket) * M + Aij = einsum("ij,kj->ik", np.conj(bra), ket) * M return Aij - def phase(self, method='max', ret_index=False): - r""" Calculate the Euler angle (phase) for the elements of the state, in the range :math:`]-\pi;\pi]` + def phase(self, method="max", ret_index=False): + r"""Calculate the Euler angle (phase) for the elements of the state, in the range :math:`]-\pi;\pi]` Parameters ---------- @@ -790,17 +808,19 @@ def phase(self, method='max', ret_index=False): ret_index : bool, optional return indices for the elements used when ``method=='max'`` """ - if method == 'max': + if method == "max": idx = np.argmax(np.absolute(self.state), 1) if ret_index: return np.angle(self.state[:, idx]), idx return np.angle(self.state[:, idx]) - elif method == 'all': + elif method == "all": return np.angle(self.state) - raise ValueError(f"{self.__class__.__name__}.phase only accepts method in [max, all]") + raise ValueError( + f"{self.__class__.__name__}.phase only accepts method in [max, all]" + ) def align_phase(self, other, ret_index=False, inplace=False): - r""" Align `self` with the phases for `other`, a copy may be returned + r"""Align `self` with the phases for `other`, a copy may be returned States will be rotated by :math:`\pi` provided the phase difference between the states are above :math:`|\Delta\theta| > \pi/2`. @@ -839,7 +859,7 @@ def align_phase(self, other, ret_index=False, inplace=False): return out def align_norm(self, other, ret_index=False, inplace=False): - r""" Align `self` with the site-norms of `other`, a copy may optionally be returned + r"""Align `self` with the site-norms of `other`, a copy may optionally be returned To determine the new ordering of `self` first calculate the residual norm of the site-norms. @@ -883,7 +903,7 @@ def align_norm(self, other, ret_index=False, inplace=False): oidx = _a.emptyi(len(self)) for i in range(len(self)): R = snorm[i] - onorm - R = einsum('ij,ij->i', R, R) + R = einsum("ij,ij->i", R, R) # Figure out which band it should correspond to # find closest largest one @@ -895,7 +915,9 @@ def align_norm(self, other, ret_index=False, inplace=False): show_warn = True if show_warn: - warn(f"{self.__class__.__name__}.align_norm found multiple possible candidates with minimal residue, swapping not unique") + warn( + f"{self.__class__.__name__}.align_norm found multiple possible candidates with minimal residue, swapping not unique" + ) if inplace: self.sub(oidx, inplace=True) @@ -906,8 +928,8 @@ def align_norm(self, other, ret_index=False, inplace=False): else: return self.sub(oidx) - def rotate(self, phi=0., individual=False): - r""" Rotate all states (in-place) to rotate the largest component to be along the angle `phi` + def rotate(self, phi=0.0, individual=False): + r"""Rotate all states (in-place) to rotate the largest component to be along the angle `phi` The states will be rotated according to: @@ -939,7 +961,7 @@ def rotate(self, phi=0., individual=False): s *= phi * np.conj(s[idx] / np.absolute(s[idx])) def change_gauge(self, gauge, offset=(0, 0, 0)): - r""" In-place change of the gauge of the state coefficients + r"""In-place change of the gauge of the state coefficients The two gauges are related through: @@ -958,15 +980,15 @@ def change_gauge(self, gauge, offset=(0, 0, 0)): """ # These calls will fail if the gauge is not specified. # In that case it will not do anything - if self.info.get('gauge', gauge) == gauge: + if self.info.get("gauge", gauge) == gauge: # Quick return return # Update gauge value - self.info['gauge'] = gauge + self.info["gauge"] = gauge # Check that we can do a gauge transformation - k = _a.asarrayd(self.info.get('k', [0., 0., 0.])) + k = _a.asarrayd(self.info.get("k", [0.0, 0.0, 0.0])) if k.dot(k) <= 0.000000001: return @@ -981,10 +1003,10 @@ def change_gauge(self, gauge, offset=(0, 0, 0)): except Exception: pass - if gauge == 'r': + if gauge == "r": # R -> r gauge tranformation. self.state *= exp(-1j * phase).reshape(1, -1) - elif gauge == 'R': + elif gauge == "R": # r -> R gauge tranformation. self.state *= exp(1j * phase).reshape(1, -1) @@ -1038,7 +1060,7 @@ def change_gauge(self, gauge, offset=(0, 0, 0)): # I.e. we are forced to do *one* inheritance, which we choose to be State. @set_module("sisl.physics") class StateC(State): - """ An object handling a set of vectors describing a given *state* with associated coefficients `c` + """An object handling a set of vectors describing a given *state* with associated coefficients `c` Parameters ---------- @@ -1056,10 +1078,11 @@ class StateC(State): ----- This class should be subclassed! """ - __slots__ = ['c'] + + __slots__ = ["c"] def __init__(self, state, c, parent=None, **info): - """ Define a state container with a given set of states and coefficients for the states """ + """Define a state container with a given set of states and coefficients for the states""" super().__init__(state, parent, **info) self.c = np.atleast_1d(c) """ numpy.ndarray @@ -1067,17 +1090,19 @@ def __init__(self, state, c, parent=None, **info): """ if len(self.c) != len(self.state): - raise ValueError(f"{self.__class__.__name__} could not be created with coefficients and states " - "having unequal length.") + raise ValueError( + f"{self.__class__.__name__} could not be created with coefficients and states " + "having unequal length." + ) def copy(self): - """ Return a copy (only the coefficients and states are copied), ``parent`` and ``info`` are passed by reference """ + """Return a copy (only the coefficients and states are copied), ``parent`` and ``info`` are passed by reference""" copy = self.__class__(self.state.copy(), self.c.copy(), self.parent) copy.info = self.info return copy def normalize(self): - r""" Return a normalized state where each state has :math:`|\psi|^2=1` + r"""Return a normalized state where each state has :math:`|\psi|^2=1` This is roughly equivalent to: @@ -1092,12 +1117,14 @@ def normalize(self): a new state with all states normalized, otherwise equal to this """ n = self.norm() - s = self.__class__(self.state / n.reshape(-1, 1), self.c.copy(), parent=self.parent) + s = self.__class__( + self.state / n.reshape(-1, 1), self.c.copy(), parent=self.parent + ) s.info = self.info return s def sort(self, ascending=True): - """ Sort and return a new `StateC` by sorting the coefficients (default to ascending) + """Sort and return a new `StateC` by sorting the coefficients (default to ascending) Parameters ---------- @@ -1110,8 +1137,10 @@ def sort(self, ascending=True): idx = np.argsort(-self.c) return self.sub(idx) - def derivative(self, order=1, degenerate=1e-5, degenerate_dir=(1, 1, 1), matrix=False): - r""" Calculate the derivative with respect to :math:`\mathbf k` for a set of states up to a given order + def derivative( + self, order=1, degenerate=1e-5, degenerate_dir=(1, 1, 1), matrix=False + ): + r"""Calculate the derivative with respect to :math:`\mathbf k` for a set of states up to a given order These are calculated using the analytic expression (:math:`\alpha` corresponding to the Cartesian directions), here only shown for the 1st order derivative: @@ -1187,7 +1216,9 @@ def derivative(self, order=1, degenerate=1e-5, degenerate_dir=(1, 1, 1), matrix= degenerate = self.degenerate(degenerate) if order not in (1, 2): - raise NotImplementedError(f"{self.__class__.__name__}.derivative required order to be in this list: [1, 2], higher order derivatives are not implemented") + raise NotImplementedError( + f"{self.__class__.__name__}.derivative required order to be in this list: [1, 2], higher order derivatives are not implemented" + ) def add_keys(opt, *keys): for key in keys: @@ -1210,7 +1241,8 @@ def add_keys(opt, *keys): dSk = self.parent.dSk(**opt) if order > 1: ddSk = self.parent.ddSk(**opt) - except Exception: pass + except Exception: + pass # Now figure out if spin is a thing add_keys(opt, "spin") @@ -1227,12 +1259,12 @@ def add_keys(opt, *keys): if degenerate is not None: # normalize direction degenerate_dir = _a.asarrayd(degenerate_dir) - degenerate_dir /= (degenerate_dir ** 2).sum() ** 0.5 + degenerate_dir /= (degenerate_dir**2).sum() ** 0.5 # de-coupling is only done for the 1st derivative # create the degeneracy decoupling projector - deg_dPk = sum(d*dh for d, dh in zip(degenerate_dir, dPk)) + deg_dPk = sum(d * dh for d, dh in zip(degenerate_dir, dPk)) if is_orthogonal: for deg in degenerate: @@ -1246,7 +1278,7 @@ def add_keys(opt, *keys): for deg in degenerate: e = np.average(energy[deg]) energy[deg] = e - deg_dSk = sum((d*e)*ds for d, ds in zip(degenerate_dir, dSk)) + deg_dSk = sum((d * e) * ds for d, ds in zip(degenerate_dir, dSk)) state[deg] = degenerate_decouple(state[deg], deg_dPk - deg_dSk) # States have been decoupled and we can calculate things now @@ -1258,7 +1290,6 @@ def add_keys(opt, *keys): # We split everything up into orthogonal and non-orthogonal # This reduces if-checks if is_orthogonal: - if matrix or order > 1: # calculate the full matrix v = np.empty([3, nstate, nstate], dtype=opt["dtype"]) @@ -1305,11 +1336,17 @@ def add_keys(opt, *keys): # zz vv[2, s] = cstate @ ddPk[2].dot(state[s]) - de * absv[2] ** 2 # yz - vv[3, s] = cstate @ ddPk[3].dot(state[s]) - de * absv[1] * absv[2] + vv[3, s] = ( + cstate @ ddPk[3].dot(state[s]) - de * absv[1] * absv[2] + ) # xz - vv[4, s] = cstate @ ddPk[4].dot(state[s]) - de * absv[0] * absv[2] + vv[4, s] = ( + cstate @ ddPk[4].dot(state[s]) - de * absv[0] * absv[2] + ) # xy - vv[5, s] = cstate @ ddPk[5].dot(state[s]) - de * absv[0] * absv[1] + vv[5, s] = ( + cstate @ ddPk[5].dot(state[s]) - de * absv[0] * absv[1] + ) else: vv = np.empty([6, nstate], dtype=opt["dtype"]) @@ -1329,11 +1366,17 @@ def add_keys(opt, *keys): # zz vv[2, s] = cstate[s] @ ddPk[2].dot(state[s]) - de @ absv[2] ** 2 # yz - vv[3, s] = cstate[s] @ ddPk[3].dot(state[s]) - de @ (absv[1] * absv[2]) + vv[3, s] = cstate[s] @ ddPk[3].dot(state[s]) - de @ ( + absv[1] * absv[2] + ) # xz - vv[4, s] = cstate[s] @ ddPk[4].dot(state[s]) - de @ (absv[0] * absv[2]) + vv[4, s] = cstate[s] @ ddPk[4].dot(state[s]) - de @ ( + absv[0] * absv[2] + ) # xy - vv[5, s] = cstate[s] @ ddPk[5].dot(state[s]) - de @ (absv[0] * absv[1]) + vv[5, s] = cstate[s] @ ddPk[5].dot(state[s]) - de @ ( + absv[0] * absv[1] + ) ret += (vv,) else: @@ -1380,17 +1423,35 @@ def add_keys(opt, *keys): # calculate 2nd derivative # xx - vv[0, s] = cstate @ (ddPk[0] - e * ddSk[0]).dot(state[s]) - de * absv[0] ** 2 + vv[0, s] = ( + cstate @ (ddPk[0] - e * ddSk[0]).dot(state[s]) + - de * absv[0] ** 2 + ) # yy - vv[1, s] = cstate @ (ddPk[1] - e * ddSk[1]).dot(state[s]) - de * absv[1] ** 2 + vv[1, s] = ( + cstate @ (ddPk[1] - e * ddSk[1]).dot(state[s]) + - de * absv[1] ** 2 + ) # zz - vv[2, s] = cstate @ (ddPk[2] - e * ddSk[2]).dot(state[s]) - de * absv[2] ** 2 + vv[2, s] = ( + cstate @ (ddPk[2] - e * ddSk[2]).dot(state[s]) + - de * absv[2] ** 2 + ) # yz - vv[3, s] = cstate @ (ddPk[3] - e * ddSk[3]).dot(state[s]) - de * absv[1] * absv[2] + vv[3, s] = ( + cstate @ (ddPk[3] - e * ddSk[3]).dot(state[s]) + - de * absv[1] * absv[2] + ) # xz - vv[4, s] = cstate @ (ddPk[4] - e * ddSk[4]).dot(state[s]) - de * absv[0] * absv[2] + vv[4, s] = ( + cstate @ (ddPk[4] - e * ddSk[4]).dot(state[s]) + - de * absv[0] * absv[2] + ) # xy - vv[5, s] = cstate @ (ddPk[5] - e * ddSk[5]).dot(state[s]) - de * absv[0] * absv[1] + vv[5, s] = ( + cstate @ (ddPk[5] - e * ddSk[5]).dot(state[s]) + - de * absv[0] * absv[1] + ) else: vv = np.empty([6, nstate], dtype=opt["dtype"]) @@ -1404,17 +1465,32 @@ def add_keys(opt, *keys): # calculate 2nd derivative # xx - vv[0, s] = cstate[s] @ (ddPk[0] - e * ddSk[0]).dot(state[s]) - de @ absv[0] ** 2 + vv[0, s] = ( + cstate[s] @ (ddPk[0] - e * ddSk[0]).dot(state[s]) + - de @ absv[0] ** 2 + ) # yy - vv[1, s] = cstate[s] @ (ddPk[1] - e * ddSk[1]).dot(state[s]) - de @ absv[1] ** 2 + vv[1, s] = ( + cstate[s] @ (ddPk[1] - e * ddSk[1]).dot(state[s]) + - de @ absv[1] ** 2 + ) # zz - vv[2, s] = cstate[s] @ (ddPk[2] - e * ddSk[2]).dot(state[s]) - de @ absv[2] ** 2 + vv[2, s] = ( + cstate[s] @ (ddPk[2] - e * ddSk[2]).dot(state[s]) + - de @ absv[2] ** 2 + ) # yz - vv[3, s] = cstate[s] @ (ddPk[3] - e * ddSk[3]).dot(state[s]) - de @ (absv[1] * absv[2]) + vv[3, s] = cstate[s] @ (ddPk[3] - e * ddSk[3]).dot( + state[s] + ) - de @ (absv[1] * absv[2]) # xz - vv[4, s] = cstate[s] @ (ddPk[4] - e * ddSk[4]).dot(state[s]) - de @ (absv[0] * absv[2]) + vv[4, s] = cstate[s] @ (ddPk[4] - e * ddSk[4]).dot( + state[s] + ) - de @ (absv[0] * absv[2]) # xy - vv[5, s] = cstate[s] @ (ddPk[5] - e * ddSk[5]).dot(state[s]) - de @ (absv[0] * absv[1]) + vv[5, s] = cstate[s] @ (ddPk[5] - e * ddSk[5]).dot( + state[s] + ) - de @ (absv[0] * absv[1]) ret += (vv,) @@ -1423,7 +1499,7 @@ def add_keys(opt, *keys): return ret def degenerate(self, eps): - """ Find degenerate coefficients with a specified precision + """Find degenerate coefficients with a specified precision Parameters ---------- @@ -1453,7 +1529,7 @@ def degenerate(self, eps): return deg def sub(self, idx, inplace=False): - """ Return a new state with only the specified states + """Return a new state with only the specified states Parameters ---------- @@ -1477,7 +1553,7 @@ def sub(self, idx, inplace=False): return sub def remove(self, idx, inplace=False): - """ Return a new state without the specified indices + """Return a new state without the specified indices Parameters ---------- diff --git a/src/sisl/physics/tests/test_bloch.py b/src/sisl/physics/tests/test_bloch.py index 979f7177e0..a4dd716b94 100644 --- a/src/sisl/physics/tests/test_bloch.py +++ b/src/sisl/physics/tests/test_bloch.py @@ -12,7 +12,7 @@ def get_H(): - s = geom.sc(1., Atom(1, 1.001)) + s = geom.sc(1.0, Atom(1, 1.001)) H = Hamiltonian(s) H.construct([(0.1, 1.001), (0.1, 0.5)]) return H.tile(2, 0).tile(2, 1).tile(2, 2) @@ -29,7 +29,7 @@ def test_bloch_create(nx, ny, nz): def test_bloch_method(): b = Bloch([1] * 3) - assert 'Bloch' in str(b) + assert "Bloch" in str(b) def test_bloch_call(): @@ -38,11 +38,11 @@ def test_bloch_call(): # Manual k_unfold = b.unfold_points([0] * 3) - m = b.unfold(np.stack([H.Hk(k, format='array') for k in k_unfold]), k_unfold) - m1 = b.unfold([H.Hk(k, format='array') for k in k_unfold], k_unfold) + m = b.unfold(np.stack([H.Hk(k, format="array") for k in k_unfold]), k_unfold) + m1 = b.unfold([H.Hk(k, format="array") for k in k_unfold], k_unfold) assert np.allclose(m, m1) - assert np.allclose(m, b(H.Hk, [0] * 3, format='array')) + assert np.allclose(m, b(H.Hk, [0] * 3, format="array")) @pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) @@ -67,9 +67,9 @@ def test_bloch_H_same(nx, ny, nz, dtype): K = [kx, ky, kz] k_unfold = b.unfold_points(K) - HK = [H.Hk(k, format='array', dtype=dtype) for k in k_unfold] + HK = [H.Hk(k, format="array", dtype=dtype) for k in k_unfold] H_unfold = b.unfold(HK, k_unfold) - H_big = HB.Hk(K, format='array', dtype=dtype) + H_big = HB.Hk(K, format="array", dtype=dtype) assert np.allclose(H_big, H_big.T.conj(), atol=atol) assert np.allclose(H_unfold, H_unfold.T.conj(), atol=atol) diff --git a/src/sisl/physics/tests/test_brillouinzone.py b/src/sisl/physics/tests/test_brillouinzone.py index e63e4bf5aa..83f7170415 100644 --- a/src/sisl/physics/tests/test_brillouinzone.py +++ b/src/sisl/physics/tests/test_brillouinzone.py @@ -24,21 +24,21 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): self.s1 = Lattice(1, nsc=[3, 3, 1]) self.s2 = Lattice([2, 2, 10, 90, 90, 60], [5, 5, 1]) + return t() class TestBrillouinZone: - def setUp(self, setup): setup.s1 = Lattice(1, nsc=[3, 3, 1]) setup.s2 = Lattice([2, 2, 10, 90, 90, 60], [5, 5, 1]) def test_bz1(self, setup): - bz = BrillouinZone(1.) + bz = BrillouinZone(1.0) str(bz) bz.weight bz = BrillouinZone(setup.s1) @@ -50,19 +50,19 @@ def test_bz1(self, setup): for k in bz: assert np.allclose(k, np.zeros(3)) - w = 0. + w = 0.0 for k, wk in bz.iter(True): assert np.allclose(k, np.zeros(3)) w += wk - assert w == pytest.approx(1.) + assert w == pytest.approx(1.0) - bz = BrillouinZone(setup.s1, [[0]*3, [0.5]*3], [.5]*2) + bz = BrillouinZone(setup.s1, [[0] * 3, [0.5] * 3], [0.5] * 2) assert len(bz) == 2 assert len(bz.copy()) == 2 def test_weight_automatic(self, setup): - bz = BrillouinZone(1.) - assert bz.weight[0] == 1. + bz = BrillouinZone(1.0) + assert bz.weight[0] == 1.0 bz = BrillouinZone(setup.s1, np.random.rand(3, 3)) assert bz.weight.sum() == pytest.approx(1) @@ -71,7 +71,7 @@ def test_weight_automatic(self, setup): assert bz.weight.sum() == pytest.approx(1.5) def test_volume_self(self): - bz = BrillouinZone(1.) + bz = BrillouinZone(1.0) assert bz.volume(True)[1] == 0 bz = BrillouinZone(Lattice(1, nsc=[3, 1, 1])) assert bz.volume(True)[1] == 1 @@ -81,7 +81,7 @@ def test_volume_self(self): assert bz.volume(True)[1] == 3 def test_volume_direct(self): - bz = BrillouinZone(1.) + bz = BrillouinZone(1.0) assert bz.volume(True, [0, 1])[1] == 2 assert bz.volume(True, [1])[1] == 1 assert bz.volume(True, [2, 1])[1] == 2 @@ -90,7 +90,7 @@ def test_volume_direct(self): def test_fail(self, setup): with pytest.raises(ValueError): - BrillouinZone(setup.s1, [0] * 3, [.5] * 2) + BrillouinZone(setup.s1, [0] * 3, [0.5] * 2) def test_to_reduced(self, setup): bz = BrillouinZone(setup.s2) @@ -103,24 +103,30 @@ def test_class1(self, setup): class Test(LatticeChild): def __init__(self, lattice): self.set_lattice(lattice) + def eigh(self, k, *args, **kwargs): return np.arange(3) + def eig(self, k, *args, **kwargs): return np.arange(3) - 1 + bz = BrillouinZone(Test(setup.s1)) bz_arr = bz.apply.array str(bz) assert np.allclose(bz_arr.eigh(), np.arange(3)) - assert np.allclose(bz_arr.eig(), np.arange(3)-1) + assert np.allclose(bz_arr.eig(), np.arange(3) - 1) def test_class2(self, setup): class Test(LatticeChild): def __init__(self, lattice): self.set_lattice(lattice) + def eigh(self, k, *args, **kwargs): return np.arange(3) + def eig(self, k, *args, **kwargs): return np.arange(3) - 1 + bz = BrillouinZone(Test(setup.s1)) # Try the list/yield method for val in bz.apply.list.eigh(): @@ -140,29 +146,32 @@ def eig(self, k, *args, **kwargs): def test_parametrize_integer(self, setup): # parametrize for single integers def func(parent, N, i): - return [i/N, 0, 0] + return [i / N, 0, 0] + bz = BrillouinZone.parametrize(setup.s1, func, 10) assert len(bz) == 10 - assert np.allclose(bz.k[-1], [9/10, 0, 0]) + assert np.allclose(bz.k[-1], [9 / 10, 0, 0]) def test_parametrize_list(self, setup): # parametrize for single integers def func(parent, N, i): - return [i[0]/N[0], i[1]/N[1], 0] + return [i[0] / N[0], i[1] / N[1], 0] + bz = BrillouinZone.parametrize(setup.s1, func, [10, 2]) assert len(bz) == 20 - assert np.allclose(bz.k[-1], [9/10, 1/2, 0]) - assert np.allclose(bz.k[-2], [9/10, 0/2, 0]) + assert np.allclose(bz.k[-1], [9 / 10, 1 / 2, 0]) + assert np.allclose(bz.k[-2], [9 / 10, 0 / 2, 0]) def test_default_weight(self): - bz1 = BrillouinZone(geom.graphene(), [[0] * 3, [0.25] * 3], [1/2] * 2) + bz1 = BrillouinZone(geom.graphene(), [[0] * 3, [0.25] * 3], [1 / 2] * 2) bz2 = BrillouinZone(geom.graphene(), [[0] * 3, [0.25] * 3]) assert np.allclose(bz1.k, bz2.k) assert np.allclose(bz1.weight, bz2.weight) def test_pickle(self, setup): import pickle as p - bz1 = BrillouinZone(geom.graphene(), [[0] * 3, [0.25] * 3], [1/2] * 2) + + bz1 = BrillouinZone(geom.graphene(), [[0] * 3, [0.25] * 3], [1 / 2] * 2) n = p.dumps(bz1) bz2 = p.loads(n) assert np.allclose(bz1.k, bz2.k) @@ -171,17 +180,17 @@ def test_pickle(self, setup): @pytest.mark.parametrize("n", [[0, 0, 1], [0.5] * 3]) def test_param_circle(self, n): - bz = BrillouinZone.param_circle(1, 10, 0.1, n, [1/2] * 3) + bz = BrillouinZone.param_circle(1, 10, 0.1, n, [1 / 2] * 3) assert len(bz) == 10 sc = Lattice(1) - bz_loop = BrillouinZone.param_circle(sc, 10, 0.1, n, [1/2] * 3, True) + bz_loop = BrillouinZone.param_circle(sc, 10, 0.1, n, [1 / 2] * 3, True) assert len(bz_loop) == 10 assert not np.allclose(bz.k, bz_loop.k) assert np.allclose(bz_loop.k[0, :], bz_loop.k[-1, :]) def test_merge_simple(self): normal = [0] * 3 - origin = [1/2] * 3 + origin = [1 / 2] * 3 bzs = [ BrillouinZone.param_circle(1, 10, 0.1, normal, origin), @@ -191,14 +200,12 @@ def test_merge_simple(self): bz = BrillouinZone.merge(bzs) assert len(bz) == 30 assert bz.weight.sum() == pytest.approx( - bzs[0].weight.sum() + - bzs[1].weight.sum() + - bzs[2].weight.sum() + bzs[0].weight.sum() + bzs[1].weight.sum() + bzs[2].weight.sum() ) def test_merge_scales(self): normal = [0] * 3 - origin = [1/2] * 3 + origin = [1 / 2] * 3 bzs = [ BrillouinZone.param_circle(1, 10, 0.1, normal, origin), @@ -208,47 +215,41 @@ def test_merge_scales(self): bz = BrillouinZone.merge(bzs, [1, 2, 3]) assert len(bz) == 30 assert bz.weight.sum() == pytest.approx( - bzs[0].weight.sum() + - bzs[1].weight.sum() * 2 + - bzs[2].weight.sum() * 3 + bzs[0].weight.sum() + bzs[1].weight.sum() * 2 + bzs[2].weight.sum() * 3 ) def test_merge_scales_short(self): normal = [0] * 3 - origin = [1/2] * 3 + origin = [1 / 2] * 3 bzs = [ - BrillouinZone.param_circle(1, 10, 0.1, normal, [1/2] * 3), - BrillouinZone.param_circle(1, 10, 0.2, normal, [1/2] * 3), - BrillouinZone.param_circle(1, 10, 0.3, normal, [1/2] * 3), + BrillouinZone.param_circle(1, 10, 0.1, normal, [1 / 2] * 3), + BrillouinZone.param_circle(1, 10, 0.2, normal, [1 / 2] * 3), + BrillouinZone.param_circle(1, 10, 0.3, normal, [1 / 2] * 3), ] bz = BrillouinZone.merge(bzs, [1, 2]) assert len(bz) == 30 assert bz.weight.sum() == pytest.approx( - bzs[0].weight.sum() + - bzs[1].weight.sum() * 2 + - bzs[2].weight.sum() * 2 + bzs[0].weight.sum() + bzs[1].weight.sum() * 2 + bzs[2].weight.sum() * 2 ) def test_merge_scales_scalar(self): normal = [0] * 3 - origin = [1/2] * 3 + origin = [1 / 2] * 3 bzs = [ - BrillouinZone.param_circle(1, 10, 0.1, normal, [1/2] * 3), - BrillouinZone.param_circle(1, 10, 0.3, normal, [1/2] * 3), + BrillouinZone.param_circle(1, 10, 0.1, normal, [1 / 2] * 3), + BrillouinZone.param_circle(1, 10, 0.3, normal, [1 / 2] * 3), ] bz = BrillouinZone.merge(bzs, 1) assert len(bz) == 20 assert bz.weight.sum() == pytest.approx( - bzs[0].weight.sum() + - bzs[1].weight.sum() + bzs[0].weight.sum() + bzs[1].weight.sum() ) @pytest.mark.monkhorstpack class TestMonkhorstPack: - def setUp(self, setup): setup.s1 = Lattice(1, nsc=[3, 3, 1]) setup.s2 = Lattice([2, 2, 10, 90, 90, 60], [5, 5, 1]) @@ -257,10 +258,13 @@ def test_class(self, setup): class Test(LatticeChild): def __init__(self, lattice): self.set_lattice(lattice) + def eigh(self, k, *args, **kwargs): return np.arange(3) + def eig(self, k, *args, **kwargs): return np.arange(3) - 1 + bz = MonkhorstPack(Test(setup.s1), [2] * 3) # Try the yield method bz_yield = bz.apply.iter @@ -273,6 +277,7 @@ def eig(self, k, *args, **kwargs): def test_pickle(self, setup): import pickle as p + bz1 = MonkhorstPack(geom.graphene(), [10, 11, 1], centered=False) n = p.dumps(bz1) bz2 = p.loads(n) @@ -287,8 +292,10 @@ def test_asgrid(self, setup, N, centered): class Test(LatticeChild): def __init__(self, lattice): self.set_lattice(lattice) + def eigh(self, k, *args, **kwargs): return np.arange(3) + bz = MonkhorstPack(Test(setup.s1), [2] * 3).apply.grid # Check the shape @@ -296,7 +303,7 @@ def eigh(self, k, *args, **kwargs): assert np.allclose(grid.shape, [2] * 3) # Check the grids are different - grid2 = bz.eigh(grid_unit='Bohr', wrap=lambda eig: eig[0]) + grid2 = bz.eigh(grid_unit="Bohr", wrap=lambda eig: eig[0]) assert not np.allclose(grid.cell, grid2.cell) assert np.allclose(grid.grid, grid2.grid) @@ -310,8 +317,10 @@ def test_asgrid_fail(self, setup): class Test(LatticeChild): def __init__(self, lattice): self.set_lattice(lattice) + def eigh(self, k, *args, **kwargs): return np.arange(3) + bz = MonkhorstPack(Test(setup.s1), [2] * 3, displacement=[0.1] * 3).apply.grid with pytest.raises(SislError): bz.eigh(wrap=lambda eig: eig[0]) @@ -319,30 +328,31 @@ def eigh(self, k, *args, **kwargs): def test_init_simple(self, setup): bz = MonkhorstPack(setup.s1, [2] * 3, trs=False) assert len(bz) == 8 - assert bz.weight[0] == 1. / 8 + assert bz.weight[0] == 1.0 / 8 def test_displaced(self, setup): bz1 = MonkhorstPack(setup.s1, [2] * 3, centered=False, trs=False) assert len(bz1) == 8 - bz2 = MonkhorstPack(setup.s1, [2] * 3, displacement=[.5] * 3, trs=False) + bz2 = MonkhorstPack(setup.s1, [2] * 3, displacement=[0.5] * 3, trs=False) assert len(bz2) == 8 assert np.allclose(bz1.k, bz2.k) def test_uneven(self, setup): bz1 = MonkhorstPack(setup.s1, [3] * 3, trs=False) - bz2 = MonkhorstPack(setup.s1, [3] * 3, displacement=[.5] * 3, trs=False) + bz2 = MonkhorstPack(setup.s1, [3] * 3, displacement=[0.5] * 3, trs=False) assert not np.allclose(bz1.k, bz2.k) def test_size_half(self, setup): bz1 = MonkhorstPack(setup.s1, [2] * 3, size=0.5, trs=False) assert len(bz1) == 8 assert np.all(bz1.k <= 0.25) - assert bz1.weight.sum() == pytest.approx(0.5 ** 3) + assert bz1.weight.sum() == pytest.approx(0.5**3) def test_as_dataarray(self): pytest.importorskip("xarray", reason="xarray not available") from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) @@ -356,16 +366,16 @@ def test_as_dataarray(self): assert np.allclose(asarray, asdarray.values) assert isinstance(asdarray.bz, MonkhorstPack) assert isinstance(asdarray.parent, Hamiltonian) - assert asdarray.dims == ('k', 'v1') + assert asdarray.dims == ("k", "v1") - asdarray = bz_da.eigh(coords=['orb']) - assert asdarray.dims == ('k', 'orb') + asdarray = bz_da.eigh(coords=["orb"]) + assert asdarray.dims == ("k", "orb") def test_trs(self, setup): size = [0.05, 0.5, 0.9] for x, y, z in product(np.arange(10) + 1, np.arange(20) + 1, np.arange(6) + 1): bz = MonkhorstPack(setup.s1, [x, y, z]) - assert bz.weight.sum() == pytest.approx(1.) + assert bz.weight.sum() == pytest.approx(1.0) bz = MonkhorstPack(setup.s1, [x, y, z], size=size) assert bz.weight.sum() == pytest.approx(np.prod(size)) @@ -373,7 +383,7 @@ def test_gamma_centered(self, setup): for x, y, z in product(np.arange(10) + 1, np.arange(20) + 1, np.arange(6) + 1): bz = MonkhorstPack(setup.s1, [x, y, z], trs=False) assert len(bz) == x * y * z - assert ((bz.k == 0.).sum(1).astype(np.int32) == 3).sum() == 1 + assert ((bz.k == 0.0).sum(1).astype(np.int32) == 3).sum() == 1 def test_gamma_non_centered(self, setup): for x, y, z in product(np.arange(10) + 1, np.arange(20) + 1, np.arange(6) + 1): @@ -385,9 +395,9 @@ def test_gamma_non_centered(self, setup): has_gamma &= y % 2 == 1 has_gamma &= z % 2 == 1 if has_gamma: - assert ((bz.k == 0.).sum(1).astype(np.int32) == 3).sum() == 1 + assert ((bz.k == 0.0).sum(1).astype(np.int32) == 3).sum() == 1 else: - assert ((bz.k == 0.).sum(1).astype(np.int32) == 3).sum() == 0 + assert ((bz.k == 0.0).sum(1).astype(np.int32) == 3).sum() == 0 def test_gamma_centered_displ(self, setup): for x, y, z in product(np.arange(10) + 1, np.arange(20) + 1, np.arange(6) + 1): @@ -396,18 +406,19 @@ def test_gamma_centered_displ(self, setup): k[:, 0] -= 0.2 assert len(bz) == x * y * z if x % 2 == 1: - assert ((k == 0.).sum(1).astype(np.int32) == 3).sum() == 1 + assert ((k == 0.0).sum(1).astype(np.int32) == 3).sum() == 1 else: - assert ((k == 0.).sum(1).astype(np.int32) == 3).sum() == 0 + assert ((k == 0.0).sum(1).astype(np.int32) == 3).sum() == 0 def test_as_simple(self): from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) bz = MonkhorstPack(H, [2, 2, 2], trs=False) - assert len(bz) == 2 ** 3 + assert len(bz) == 2**3 # Assert that as* all does the same apply = bz.apply @@ -425,12 +436,14 @@ def test_as_dataarray_zip(self): pytest.importorskip("xarray", reason="xarray not available") from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) E = np.linspace(-2, 2, 20) bz = MonkhorstPack(H, [2, 2, 1], trs=False) + def wrap(es): return es.eig, es.DOS(E), es.PDOS(E) @@ -438,14 +451,16 @@ def wrap(es): eig, DOS, PDOS = unzip.ndarray.eigenstate(wrap=wrap) ds0 = unzip.dataarray.eigenstate(wrap=wrap, name=["eig", "DOS", "PDOS"]) # explicitly create dimensions - ds1 = unzip.dataarray.eigenstate(wrap=wrap, - coords=[ - {"orb": np.arange(len(H))}, - {"E": E}, - {"spin": [0], "orb": np.arange(len(H)), "E": E}, - ], - dims=(['orb'], ['E'], ['spin', 'orb', 'E']), - name=["eig", "DOS", "PDOS"]) + ds1 = unzip.dataarray.eigenstate( + wrap=wrap, + coords=[ + {"orb": np.arange(len(H))}, + {"E": E}, + {"spin": [0], "orb": np.arange(len(H)), "E": E}, + ], + dims=(["orb"], ["E"], ["spin", "orb", "E"]), + name=["eig", "DOS", "PDOS"], + ) for var, data in zip(["eig", "DOS", "PDOS"], [eig, DOS, PDOS]): assert np.allclose(ds0.data_vars[var].values, data) @@ -453,10 +468,13 @@ def wrap(es): assert len(ds1.coords) < len(ds0.coords) def test_pathos(self): - pytest.skip("BrillouinZone.apply(pool=True|int) scales extremely bad and may cause stall") + pytest.skip( + "BrillouinZone.apply(pool=True|int) scales extremely bad and may cause stall" + ) pytest.importorskip("pathos", reason="pathos not available") from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) @@ -465,8 +483,10 @@ def test_pathos(self): # try and determine a sensible import os + try: import psutil + nprocs = len(psutil.Process().cpu_affinity()) // 2 except Exception: nprocs = os.cpu_count() // 2 @@ -511,6 +531,7 @@ def test_pathos(self): def test_as_single(self): from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) @@ -519,7 +540,7 @@ def wrap(eig): return eig[0] bz = MonkhorstPack(H, [2, 2, 2], trs=False) - assert len(bz) == 2 ** 3 + assert len(bz) == 2**3 # Assert that as* all does the same asarray = bz.apply.array.eigh(wrap=wrap) @@ -530,12 +551,13 @@ def wrap(eig): def test_as_wrap(self): from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) bz = MonkhorstPack(H, [2, 2, 2], trs=False) - assert len(bz) == 2 ** 3 + assert len(bz) == 2**3 # Check with a wrap function def wrap(arg): @@ -556,15 +578,17 @@ def wrap(arg): def test_as_wrap_default_oplist(self): from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) bz = MonkhorstPack(H, [2, 2, 2], trs=False) - assert len(bz) == 2 ** 3 + assert len(bz) == 2**3 # Check with a wrap function E = np.linspace(-2, 2, 20) + def wrap_sum(es, weight): PDOS = es.PDOS(E)[0] * weight return PDOS.sum(0), PDOS @@ -576,6 +600,7 @@ def wrap_sum(es, weight): def test_wrap_unzip(self): from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) @@ -584,6 +609,7 @@ def test_wrap_unzip(self): # Check with a wrap function E = np.linspace(-2, 2, 20) + def wrap(es): return es.eig, es.DOS(E) @@ -592,8 +618,8 @@ def wrap(es): eig1, DOS1 = k_unzip.list.eigenstate(wrap=wrap) eig2, DOS2 = k_unzip.array.eigenstate(wrap=wrap) - #eig0 and DOS0 are generators, and not list's - #eig1 and DOS1 are generators, and not list's + # eig0 and DOS0 are generators, and not list's + # eig1 and DOS1 are generators, and not list's assert isinstance(eig2, np.ndarray) assert isinstance(DOS2, np.ndarray) @@ -605,28 +631,49 @@ def wrap(es): # Check with a wrap function and the weight argument def test_wrap_kwargs(arg): from sisl import Hamiltonian, geom + g = geom.graphene() H = Hamiltonian(g) H.construct([[0.1, 1.44], [0, -2.7]]) bz = MonkhorstPack(H, [2, 2, 2], trs=False) - assert len(bz) == 2 ** 3 + assert len(bz) == 2**3 def wrap_none(arg): return arg + def wrap_kwargs(arg, parent, k, weight): return arg * weight E = np.linspace(-2, 2, 20) bz_array = bz.apply.array - asarray1 = (bz_array.eigenstate(wrap=lambda es: es.DOS(E)) * bz.weight.reshape(-1, 1)).sum(0) - asarray2 = bz_array.eigenstate(wrap=lambda es, parent, k, weight: es.DOS(E)).sum(0) + asarray1 = ( + bz_array.eigenstate(wrap=lambda es: es.DOS(E)) * bz.weight.reshape(-1, 1) + ).sum(0) + asarray2 = bz_array.eigenstate( + wrap=lambda es, parent, k, weight: es.DOS(E) + ).sum(0) bz_list = bz.apply.list - aslist1 = (np.array(bz_list.eigenstate(wrap=lambda es: es.DOS(E))) * bz.weight.reshape(-1, 1)).sum(0) - aslist2 = np.array(bz_list.eigenstate(wrap=lambda es, parent, k, weight: es.DOS(E))).sum(0) + aslist1 = ( + np.array(bz_list.eigenstate(wrap=lambda es: es.DOS(E))) + * bz.weight.reshape(-1, 1) + ).sum(0) + aslist2 = np.array( + bz_list.eigenstate(wrap=lambda es, parent, k, weight: es.DOS(E)) + ).sum(0) bz_yield = bz.apply.iter - asyield1 = (np.array([a for a in bz_yield.eigenstate(wrap=lambda es: es.DOS(E))]) * bz.weight.reshape(-1, 1)).sum(0) - asyield2 = np.array([a for a in bz_yield.eigenstate(wrap=lambda es, parent, k, weight: es.DOS(E))]).sum(0) + asyield1 = ( + np.array([a for a in bz_yield.eigenstate(wrap=lambda es: es.DOS(E))]) + * bz.weight.reshape(-1, 1) + ).sum(0) + asyield2 = np.array( + [ + a + for a in bz_yield.eigenstate( + wrap=lambda es, parent, k, weight: es.DOS(E) + ) + ] + ).sum(0) asaverage = bz.apply.average.eigenstate(wrap=lambda es: es.DOS(E)) assum = bz.apply.sum.eigenstate(wrap=lambda es: es.DOS(E)) @@ -643,10 +690,10 @@ def test_replace_gamma(self): g = geom.graphene() bz = MonkhorstPack(g, 2, trs=False) bz_gamma = MonkhorstPack(g, [2, 2, 2], size=[0.5] * 3, trs=False) - assert len(bz) == 2 ** 3 + assert len(bz) == 2**3 bz.replace([0] * 3, bz_gamma) - assert len(bz) == 2 ** 3 + 2 ** 3 - 1 - assert bz.weight.sum() == pytest.approx(1.) + assert len(bz) == 2**3 + 2**3 - 1 + assert bz.weight.sum() == pytest.approx(1.0) assert np.allclose(bz.copy().k, bz.k) assert np.allclose(bz.copy().weight, bz.weight) @@ -658,13 +705,19 @@ def test_replace_gamma_trs(self): N_bz_gamma = len(bz_gamma) bz.replace([0] * 3, bz_gamma) assert len(bz) == N_bz + N_bz_gamma - 1 - assert bz.weight.sum() == pytest.approx(1.) + assert bz.weight.sum() == pytest.approx(1.0) def test_replace_trs_neg(self): g = geom.graphene() bz_big = MonkhorstPack(g, [6, 6, 1], trs=True) N_bz_big = len(bz_big) - bz_small = MonkhorstPack(g, [3, 3, 3], size=[1/6, 1/6, 1], displacement=[2/3, 1/3, 0], trs=True) + bz_small = MonkhorstPack( + g, + [3, 3, 3], + size=[1 / 6, 1 / 6, 1], + displacement=[2 / 3, 1 / 3, 0], + trs=True, + ) N_bz_small = len(bz_small) # it should be the same for both negative|positive displ @@ -675,27 +728,26 @@ def test_replace_trs_neg(self): bz_neg.replace(-bz_small.displacement, bz_small) for bz in [bz_pos, bz_neg]: assert len(bz) == N_bz_big + N_bz_small - 1 - assert bz.weight.sum() == pytest.approx(1.) + assert bz.weight.sum() == pytest.approx(1.0) def test_in_primitive(self): - assert np.allclose(MonkhorstPack.in_primitive([[1.] * 3, [-1.] * 3]), 0) + assert np.allclose(MonkhorstPack.in_primitive([[1.0] * 3, [-1.0] * 3]), 0) @pytest.mark.bandstructure class TestBandStructure: - def setUp(self, setup): setup.s1 = Lattice(1, nsc=[3, 3, 1]) setup.s2 = Lattice([2, 2, 10, 90, 90, 60], [5, 5, 1]) def test_pbz1(self, setup): - bz = BandStructure(setup.s1, [[0]*3, [.5]*3], 300) + bz = BandStructure(setup.s1, [[0] * 3, [0.5] * 3], 300) assert len(bz) == 300 - bz2 = BandStructure(setup.s1, [[0]*2, [.5]*2], 300, ['A', 'C']) + bz2 = BandStructure(setup.s1, [[0] * 2, [0.5] * 2], 300, ["A", "C"]) assert len(bz) == 300 - bz3 = BandStructure(setup.s1, [[0]*2, [.5]*2], [150]) + bz3 = BandStructure(setup.s1, [[0] * 2, [0.5] * 2], [150]) assert len(bz) == 300 bz.lineartick() bz.lineark() @@ -703,11 +755,11 @@ def test_pbz1(self, setup): @pytest.mark.parametrize("n", range(3, 100, 10)) def test_pbz2(self, setup, n): - bz = BandStructure(setup.s1, [[0]*3, [.25]*3, [.5]*3], n) + bz = BandStructure(setup.s1, [[0] * 3, [0.25] * 3, [0.5] * 3], n) assert len(bz) == n def test_pbs_divisions(self, setup): - bz = BandStructure(setup.s1, [[0]*3, [.25]*3, [.5]*3], [10, 10]) + bz = BandStructure(setup.s1, [[0] * 3, [0.25] * 3, [0.5] * 3], [10, 10]) assert len(bz) == 21 def test_pbs_missing_arguments(self, setup): @@ -716,15 +768,16 @@ def test_pbs_missing_arguments(self, setup): def test_pbs_fail(self, setup): with pytest.raises(ValueError): - BandStructure(setup.s1, [[0]*3, [.5]*3, [.25] * 3], 1) + BandStructure(setup.s1, [[0] * 3, [0.5] * 3, [0.25] * 3], 1) with pytest.raises(ValueError): - BandStructure(setup.s1, [[0]*3, [.5]*3, [.25] * 3], [1, 1, 1, 1]) + BandStructure(setup.s1, [[0] * 3, [0.5] * 3, [0.25] * 3], [1, 1, 1, 1]) with pytest.raises(ValueError): - BandStructure(setup.s1, [[0]*3, [.5]*3, [.25] * 3], [1, 1, 1]) + BandStructure(setup.s1, [[0] * 3, [0.5] * 3, [0.25] * 3], [1, 1, 1]) def test_pickle(self, setup): import pickle as p - bz1 = BandStructure(setup.s1, [[0]*2, [.5]*2], 300, ['A', 'C']) + + bz1 = BandStructure(setup.s1, [[0] * 2, [0.5] * 2], 300, ["A", "C"]) n = p.dumps(bz1) bz2 = p.loads(n) assert np.allclose(bz1.k, bz2.k) @@ -736,37 +789,61 @@ def test_pickle(self, setup): def test_jump(self): g = geom.graphene() - bs = BandStructure(g, [[0]*3, [0.5, 0, 0], None, [0]*3, [0., 0.5, 0]], 30, ['A', 'B', 'C', 'D']) + bs = BandStructure( + g, + [[0] * 3, [0.5, 0, 0], None, [0] * 3, [0.0, 0.5, 0]], + 30, + ["A", "B", "C", "D"], + ) assert len(bs) == 30 def test_jump_skipping_none(self): g = geom.graphene() - bs1 = BandStructure(g, [[0]*3, [0.5, 0, 0], None, [0]*3, [0., 0.5, 0]], 30, ['A', 'B', 'C', 'D']) - bs2 = BandStructure(g, [[0]*3, [0.5, 0, 0], None, [0]*3, [0., 0.5, 0], None], 30, ['A', 'B', 'C', 'D']) + bs1 = BandStructure( + g, + [[0] * 3, [0.5, 0, 0], None, [0] * 3, [0.0, 0.5, 0]], + 30, + ["A", "B", "C", "D"], + ) + bs2 = BandStructure( + g, + [[0] * 3, [0.5, 0, 0], None, [0] * 3, [0.0, 0.5, 0], None], + 30, + ["A", "B", "C", "D"], + ) assert np.allclose(bs1.k, bs2.k) def test_insert_jump(self): g = geom.graphene() nk = 10 - bs = BandStructure(g, [[0]*3, [0.5, 0, 0], None, [0]*3, None, [0., 0.5, 0]], nk, ['A', 'B', 'C', 'D']) + bs = BandStructure( + g, + [[0] * 3, [0.5, 0, 0], None, [0] * 3, None, [0.0, 0.5, 0]], + nk, + ["A", "B", "C", "D"], + ) d = np.empty([nk]) d_jump = bs.insert_jump(d) - assert d_jump.shape == (nk+2,) + assert d_jump.shape == (nk + 2,) d = np.empty([nk, 5]) d_jump = bs.insert_jump(d) - assert d_jump.shape == (nk+2, 5) + assert d_jump.shape == (nk + 2, 5) assert np.isnan(d_jump).sum() == 10 d_jump = bs.insert_jump(d.T, value=np.inf) - assert d_jump.shape == (5, nk+2) + assert d_jump.shape == (5, nk + 2) assert np.isinf(d_jump).sum() == 10 def test_insert_jump_fail(self): g = geom.graphene() nk = 10 - bs = BandStructure(g, [[0]*3, [0.5, 0, 0], None, [0]*3, [0., 0.5, 0]], nk, ['A', 'B', 'C', 'D']) - d = np.empty([nk+1]) + bs = BandStructure( + g, + [[0] * 3, [0.5, 0, 0], None, [0] * 3, [0.0, 0.5, 0]], + nk, + ["A", "B", "C", "D"], + ) + d = np.empty([nk + 1]) with pytest.raises(ValueError): bs.insert_jump(d) - diff --git a/src/sisl/physics/tests/test_density_matrix.py b/src/sisl/physics/tests/test_density_matrix.py index 5c54e9ab2f..2bbd0bd67f 100644 --- a/src/sisl/physics/tests/test_density_matrix.py +++ b/src/sisl/physics/tests/test_density_matrix.py @@ -20,65 +20,73 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): self.bond = bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) n = 60 rf = np.linspace(0, bond * 1.01, n) rf = (rf, rf) - orb = SphericalOrbital(1, rf, 2.) + orb = SphericalOrbital(1, rf, 2.0) C = Atom(6, orb.toAtomicOrbital()) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.D = DensityMatrix(self.g) self.DS = DensityMatrix(self.g, orthogonal=False) def func(D, ia, atoms, atoms_xyz): - idx = D.geometry.close(ia, R=(0.1, 1.44), atoms=atoms, atoms_xyz=atoms_xyz) + idx = D.geometry.close( + ia, R=(0.1, 1.44), atoms=atoms, atoms_xyz=atoms_xyz + ) ia = ia * 3 i0 = idx[0] * 3 i1 = idx[1] * 3 # on-site - p = 1. + p = 1.0 D.D[ia, i0] = p - D.D[ia+1, i0+1] = p - D.D[ia+2, i0+2] = p + D.D[ia + 1, i0 + 1] = p + D.D[ia + 2, i0 + 2] = p # nn p = 0.1 # on-site directions - D.D[ia, ia+1] = p - D.D[ia, ia+2] = p - D.D[ia+1, ia] = p - D.D[ia+1, ia+2] = p - D.D[ia+2, ia] = p - D.D[ia+2, ia+1] = p + D.D[ia, ia + 1] = p + D.D[ia, ia + 2] = p + D.D[ia + 1, ia] = p + D.D[ia + 1, ia + 2] = p + D.D[ia + 2, ia] = p + D.D[ia + 2, ia + 1] = p - D.D[ia, i1+1] = p - D.D[ia, i1+2] = p + D.D[ia, i1 + 1] = p + D.D[ia, i1 + 2] = p - D.D[ia+1, i1] = p - D.D[ia+1, i1+2] = p + D.D[ia + 1, i1] = p + D.D[ia + 1, i1 + 2] = p - D.D[ia+2, i1] = p - D.D[ia+2, i1+1] = p + D.D[ia + 2, i1] = p + D.D[ia + 2, i1 + 1] = p self.func = func + return t() @pytest.mark.physics @pytest.mark.density_matrix class TestDensityMatrix: - def test_objects(self, setup): assert len(setup.D.xyz) == 2 assert setup.g.no == len(setup.D) @@ -91,58 +99,64 @@ def test_ortho(self, setup): def test_set1(self, setup): D = setup.D.copy() - D.D[0, 0] = 1. - assert D[0, 0] == 1. - assert D[1, 0] == 0. + D.D[0, 0] = 1.0 + assert D[0, 0] == 1.0 + assert D[1, 0] == 0.0 def test_mulliken(self, setup): D = setup.D.copy() D.construct(setup.func) - mulliken = D.mulliken('atom') + mulliken = D.mulliken("atom") assert mulliken.shape == (len(D.geometry),) - mulliken = D.mulliken('orbital') + mulliken = D.mulliken("orbital") assert mulliken.shape == (len(D),) def test_mulliken_values_orthogonal(self, setup): D = setup.D.copy() - D.D[0, 0] = 1. - D.D[1, 1] = 2. - D.D[1, 2] = 2. - mulliken = D.mulliken('orbital') - assert np.allclose(mulliken[:2], [1., 2.]) + D.D[0, 0] = 1.0 + D.D[1, 1] = 2.0 + D.D[1, 2] = 2.0 + mulliken = D.mulliken("orbital") + assert np.allclose(mulliken[:2], [1.0, 2.0]) assert mulliken.sum() == pytest.approx(3) - mulliken = D.mulliken('atom') + mulliken = D.mulliken("atom") assert mulliken[0] == pytest.approx(3) assert mulliken.sum() == pytest.approx(3) def test_mulliken_values_non_orthogonal(self, setup): D = setup.DS.copy() - D[0, 0] = (1., 1.) - D[1, 1] = (2., 1.) - D[1, 2] = (2., 0.5) - mulliken = D.mulliken('orbital') - assert np.allclose(mulliken[:2], [1., 3.]) + D[0, 0] = (1.0, 1.0) + D[1, 1] = (2.0, 1.0) + D[1, 2] = (2.0, 0.5) + mulliken = D.mulliken("orbital") + assert np.allclose(mulliken[:2], [1.0, 3.0]) assert mulliken.sum() == pytest.approx(4) - mulliken = D.mulliken('atom') + mulliken = D.mulliken("atom") assert mulliken[0] == pytest.approx(4) assert mulliken.sum() == pytest.approx(4) def test_mulliken_polarized(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('P')) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("P")) # 1 charge onsite for each spin-up # 0.5 charge onsite for each spin-down - D.construct([[0.1, bond + 0.01], [(1., 0.5), (0.1, 0.1)]]) + D.construct([[0.1, bond + 0.01], [(1.0, 0.5), (0.1, 0.1)]]) m = D.mulliken("orbital") assert m[0].sum() == pytest.approx(3) @@ -160,43 +174,59 @@ def test_rho1(self, setup): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_rho2(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) n = 60 rf = np.linspace(0, bond * 1.01, n) rf = (rf, rf) - orb = SphericalOrbital(1, rf, 2.) + orb = SphericalOrbital(1, rf, 2.0) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) D = DensityMatrix(g) - D.construct([[0.1, bond + 0.01], [1., 0.1]]) + D.construct([[0.1, bond + 0.01], [1.0, 0.1]]) grid = Grid(0.2, geometry=D.geometry) D.density(grid) - D = DensityMatrix(g, spin=Spin('P')) - D.construct([[0.1, bond + 0.01], [(1., 0.5), (0.1, 0.1)]]) + D = DensityMatrix(g, spin=Spin("P")) + D.construct([[0.1, bond + 0.01], [(1.0, 0.5), (0.1, 0.1)]]) grid = Grid(0.2, geometry=D.geometry) D.density(grid) - D.density(grid, [1., -1]) + D.density(grid, [1.0, -1]) D.density(grid, 0) D.density(grid, 1) - D = DensityMatrix(g, spin=Spin('NC')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01), (0.1, 0.1, 0.1, 0.1)]]) + D = DensityMatrix(g, spin=Spin("NC")) + D.construct( + [[0.1, bond + 0.01], [(1.0, 0.5, 0.01, 0.01), (0.1, 0.1, 0.1, 0.1)]] + ) grid = Grid(0.2, geometry=D.geometry) D.density(grid) - D.density(grid, [[1., 0.], [0., -1]]) - - D = DensityMatrix(g, spin=Spin('SO')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01, 0.01, 0.01, 0., 0.), (0.1, 0.1, 0.1, 0.1, 0., 0., 0., 0.)]]) + D.density(grid, [[1.0, 0.0], [0.0, -1]]) + + D = DensityMatrix(g, spin=Spin("SO")) + D.construct( + [ + [0.1, bond + 0.01], + [ + (1.0, 0.5, 0.01, 0.01, 0.01, 0.01, 0.0, 0.0), + (0.1, 0.1, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0), + ], + ] + ) grid = Grid(0.2, geometry=D.geometry) D.density(grid) - D.density(grid, [[1., 0.], [0., -1]]) + D.density(grid, [[1.0, 0.0], [0.0, -1]]) D.density(grid, Spin.X) D.density(grid, Spin.Y) D.density(grid, Spin.Z) @@ -204,35 +234,55 @@ def test_rho2(self): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_orbital_momentum(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('SO')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01, 0.01, 0.01, 0., 0.), (0.1, 0.1, 0.1, 0.1, 0., 0., 0., 0.)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("SO")) + D.construct( + [ + [0.1, bond + 0.01], + [ + (1.0, 0.5, 0.01, 0.01, 0.01, 0.01, 0.0, 0.0), + (0.1, 0.1, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0), + ], + ] + ) D.orbital_momentum("atom") D.orbital_momentum("orbital") def test_spin_align_pol(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('p')) - D.construct([[0.1, bond + 0.01], [(1., 0.5), (0.1, 0.2)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("p")) + D.construct([[0.1, bond + 0.01], [(1.0, 0.5), (0.1, 0.2)]]) D_mull = D.mulliken() assert D_mull.shape == (2, len(D)) @@ -246,18 +296,26 @@ def test_spin_align_pol(self): def test_spin_align_nc(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('nc')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01), (0.1, 0.2, 0.1, 0.1)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("nc")) + D.construct( + [[0.1, bond + 0.01], [(1.0, 0.5, 0.01, 0.01), (0.1, 0.2, 0.1, 0.1)]] + ) D_mull = D.mulliken() v = np.array([1, 2, 3]) d = D.spin_align(v) @@ -268,18 +326,32 @@ def test_spin_align_nc(self): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_spin_align_so(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('SO')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01, 0.01, 0.01, 0.2, 0.2), (0.1, 0.2, 0.1, 0.1, 0., 0.1, 0.2, 0.3)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("SO")) + D.construct( + [ + [0.1, bond + 0.01], + [ + (1.0, 0.5, 0.01, 0.01, 0.01, 0.01, 0.2, 0.2), + (0.1, 0.2, 0.1, 0.1, 0.0, 0.1, 0.2, 0.3), + ], + ] + ) D_mull = D.mulliken() v = np.array([1, 2, 3]) d = D.spin_align(v) @@ -289,18 +361,24 @@ def test_spin_align_so(self): def test_spin_rotate_pol(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('p')) - D.construct([[0.1, bond + 0.01], [(1., 0.5), (0.1, 0.2)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("p")) + D.construct([[0.1, bond + 0.01], [(1.0, 0.5), (0.1, 0.2)]]) D_mull = D.mulliken() assert D_mull.shape == (2, len(D)) @@ -314,18 +392,26 @@ def test_spin_rotate_pol(self): def test_spin_rotate_nc(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('nc')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01), (0.1, 0.2, 0.1, 0.1)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("nc")) + D.construct( + [[0.1, bond + 0.01], [(1.0, 0.5, 0.01, 0.01), (0.1, 0.2, 0.1, 0.1)]] + ) D_mull = D.mulliken() d = D.spin_rotate([45, 60, 90], rad=False) @@ -338,18 +424,32 @@ def test_spin_rotate_nc(self): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_spin_rotate_so(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - - orb = AtomicOrbital('px', R=bond * 1.001) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + + orb = AtomicOrbital("px", R=bond * 1.001) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - D = DensityMatrix(g, spin=Spin('SO')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01, 0.01, 0.01, 0.2, 0.2), (0.1, 0.2, 0.1, 0.1, 0., 0.1, 0.2, 0.3)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + D = DensityMatrix(g, spin=Spin("SO")) + D.construct( + [ + [0.1, bond + 0.01], + [ + (1.0, 0.5, 0.01, 0.01, 0.01, 0.01, 0.2, 0.2), + (0.1, 0.2, 0.1, 0.1, 0.0, 0.1, 0.2, 0.3), + ], + ] + ) D_mull = D.mulliken() d = D.spin_rotate([45, 60, 90], rad=False) d_mull = d.mulliken() @@ -371,50 +471,65 @@ def test_rho_smaller_grid1(self, setup): def test_rho_fail_p(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) n = 60 rf = np.linspace(0, bond * 1.01, n) rf = (rf, rf) - orb = SphericalOrbital(1, rf, 2.) + orb = SphericalOrbital(1, rf, 2.0) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - - D = DensityMatrix(g, spin=Spin('P')) - D.construct([[0.1, bond + 0.01], [(1., 0.5), (0.1, 0.1)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + + D = DensityMatrix(g, spin=Spin("P")) + D.construct([[0.1, bond + 0.01], [(1.0, 0.5), (0.1, 0.1)]]) grid = Grid(0.2, geometry=D.geometry) with pytest.raises(ValueError): - D.density(grid, [1., -1, 0.]) + D.density(grid, [1.0, -1, 0.0]) def test_rho_fail_nc(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) n = 60 rf = np.linspace(0, bond * 1.01, n) rf = (rf, rf) - orb = SphericalOrbital(1, rf, 2.) + orb = SphericalOrbital(1, rf, 2.0) C = Atom(6, orb) - g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=lattice) - - D = DensityMatrix(g, spin=Spin('NC')) - D.construct([[0.1, bond + 0.01], [(1., 0.5, 0.01, 0.01), (0.1, 0.1, 0.1, 0.1)]]) + g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=lattice, + ) + + D = DensityMatrix(g, spin=Spin("NC")) + D.construct( + [[0.1, bond + 0.01], [(1.0, 0.5, 0.01, 0.01), (0.1, 0.1, 0.1, 0.1)]] + ) grid = Grid(0.2, geometry=D.geometry) with pytest.raises(ValueError): - D.density(grid, [1., 0.]) + D.density(grid, [1.0, 0.0]) def test_pickle(self, setup): import pickle as p + D = setup.D.copy() D.construct(setup.func) s = p.dumps(D) @@ -423,21 +538,21 @@ def test_pickle(self, setup): assert np.allclose(d.eigh(), D.eigh()) def test_transform(self, setup): - D = DensityMatrix(setup.g, spin='so') + D = DensityMatrix(setup.g, spin="so") a = np.arange(8) for ia in setup.g: D[ia, ia] = a Dcsr = [D.tocsr(i) for i in range(D.shape[2])] - Dt = D.transform(spin='unpolarized', dtype=np.float32) + Dt = D.transform(spin="unpolarized", dtype=np.float32) assert np.abs(0.5 * Dcsr[0] + 0.5 * Dcsr[1] - Dt.tocsr(0)).sum() == 0 - Dt = D.transform(spin='polarized', orthogonal=False) + Dt = D.transform(spin="polarized", orthogonal=False) assert np.abs(Dcsr[0] - Dt.tocsr(0)).sum() == 0 assert np.abs(Dcsr[1] - Dt.tocsr(1)).sum() == 0 assert np.abs(Dt.tocsr(2)).sum() != 0 - Dt = D.transform(spin='non-colinear', orthogonal=False) + Dt = D.transform(spin="non-colinear", orthogonal=False) assert np.abs(Dcsr[0] - Dt.tocsr(0)).sum() == 0 assert np.abs(Dcsr[1] - Dt.tocsr(1)).sum() == 0 assert np.abs(Dcsr[2] - Dt.tocsr(2)).sum() == 0 @@ -445,28 +560,28 @@ def test_transform(self, setup): assert np.abs(Dt.tocsr(-1)).sum() != 0 def test_transform_nonortho(self, setup): - D = DensityMatrix(setup.g, spin='polarized', orthogonal=False) + D = DensityMatrix(setup.g, spin="polarized", orthogonal=False) a = np.arange(3) - a[-1] = 1. + a[-1] = 1.0 for ia in setup.g: D[ia, ia] = a - Dt = D.transform(spin='unpolarized', dtype=np.float32) + Dt = D.transform(spin="unpolarized", dtype=np.float32) assert np.abs(0.5 * D.tocsr(0) + 0.5 * D.tocsr(1) - Dt.tocsr(0)).sum() == 0 assert np.abs(D.tocsr(-1) - Dt.tocsr(-1)).sum() == 0 - Dt = D.transform(spin='polarized') + Dt = D.transform(spin="polarized") assert np.abs(D.tocsr(0) - Dt.tocsr(0)).sum() == 0 assert np.abs(D.tocsr(1) - Dt.tocsr(1)).sum() == 0 - Dt = D.transform(spin='polarized', orthogonal=True) + Dt = D.transform(spin="polarized", orthogonal=True) assert np.abs(D.tocsr(0) - Dt.tocsr(0)).sum() == 0 assert np.abs(D.tocsr(1) - Dt.tocsr(1)).sum() == 0 - Dt = D.transform(spin='non-colinear', orthogonal=False) + Dt = D.transform(spin="non-colinear", orthogonal=False) assert np.abs(D.tocsr(0) - Dt.tocsr(0)).sum() == 0 assert np.abs(D.tocsr(1) - Dt.tocsr(1)).sum() == 0 assert np.abs(Dt.tocsr(2)).sum() == 0 assert np.abs(Dt.tocsr(3)).sum() == 0 assert np.abs(D.tocsr(-1) - Dt.tocsr(-1)).sum() == 0 - Dt = D.transform(spin='so', orthogonal=True) + Dt = D.transform(spin="so", orthogonal=True) assert np.abs(D.tocsr(0) - Dt.tocsr(0)).sum() == 0 assert np.abs(D.tocsr(1) - Dt.tocsr(1)).sum() == 0 assert np.abs(Dt.tocsr(-1)).sum() == 0 diff --git a/src/sisl/physics/tests/test_distribution.py b/src/sisl/physics/tests/test_distribution.py index e4449dcb02..f05df43b23 100644 --- a/src/sisl/physics/tests/test_distribution.py +++ b/src/sisl/physics/tests/test_distribution.py @@ -14,39 +14,39 @@ def test_distribution1(): x = np.linspace(-2, 2, 10000) dx = x[1] - x[0] - d = get_distribution('gaussian', smearing=0.025) + d = get_distribution("gaussian", smearing=0.025) assert d(x).sum() * dx == pytest.approx(1, abs=1e-6) - d = get_distribution('lorentzian', smearing=1e-3) + d = get_distribution("lorentzian", smearing=1e-3) assert d(x).sum() * dx == pytest.approx(1, abs=1e-3) def test_distribution2(): x = np.linspace(-2, 2, 10000) - d = get_distribution('gaussian', smearing=0.025) + d = get_distribution("gaussian", smearing=0.025) assert np.allclose(d(x), gaussian(x, 0.025)) - d = get_distribution('lorentzian', smearing=1e-3) + d = get_distribution("lorentzian", smearing=1e-3) assert np.allclose(d(x), lorentzian(x, 1e-3)) - d = get_distribution('step', smearing=1e-3) + d = get_distribution("step", smearing=1e-3) assert np.allclose(d(x), step_function(x)) - d = get_distribution('heaviside', smearing=1e-3, x0=1) + d = get_distribution("heaviside", smearing=1e-3, x0=1) assert np.allclose(d(x), 1 - step_function(x, 1)) - d = get_distribution('heaviside', x0=-0.5) + d = get_distribution("heaviside", x0=-0.5) assert np.allclose(d(x), heaviside(x, -0.5)) def test_distribution3(): with pytest.raises(ValueError): - get_distribution('unknown-function') + get_distribution("unknown-function") def test_distribution_x0(): x1 = np.linspace(-2, 2, 10000) x2 = np.linspace(-3, 1, 10000) - d1 = get_distribution('gaussian') - d2 = get_distribution('gaussian', x0=-1) + d1 = get_distribution("gaussian") + d2 = get_distribution("gaussian", x0=-1) assert np.allclose(d1(x1), d2(x2)) - d1 = get_distribution('lorentzian') - d2 = get_distribution('lorentzian', x0=-1) + d1 = get_distribution("lorentzian") + d2 = get_distribution("lorentzian", x0=-1) assert np.allclose(d1(x1), d2(x2)) diff --git a/src/sisl/physics/tests/test_dynamical_matrix.py b/src/sisl/physics/tests/test_dynamical_matrix.py index 125fcef7fb..de89a81c8d 100644 --- a/src/sisl/physics/tests/test_dynamical_matrix.py +++ b/src/sisl/physics/tests/test_dynamical_matrix.py @@ -13,58 +13,66 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) C = Atom(Z=6, R=[bond * 1.01] * 3) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.D = DynamicalMatrix(self.g) def func(D, ia, idxs, idxs_xyz): - idx = D.geometry.close(ia, R=(0.1, 1.44), atoms=idxs, atoms_xyz=idxs_xyz) + idx = D.geometry.close( + ia, R=(0.1, 1.44), atoms=idxs, atoms_xyz=idxs_xyz + ) ia = ia * 3 i0 = idx[0] * 3 i1 = idx[1] * 3 # on-site - p = 1. + p = 1.0 D.D[ia, i0] = p - D.D[ia+1, i0+1] = p - D.D[ia+2, i0+2] = p + D.D[ia + 1, i0 + 1] = p + D.D[ia + 2, i0 + 2] = p # nn p = 0.1 # on-site directions - D.D[ia, ia+1] = p - D.D[ia, ia+2] = p - D.D[ia+1, ia] = p - D.D[ia+1, ia+2] = p - D.D[ia+2, ia] = p - D.D[ia+2, ia+1] = p + D.D[ia, ia + 1] = p + D.D[ia, ia + 2] = p + D.D[ia + 1, ia] = p + D.D[ia + 1, ia + 2] = p + D.D[ia + 2, ia] = p + D.D[ia + 2, ia + 1] = p - D.D[ia, i1+1] = p - D.D[ia, i1+2] = p + D.D[ia, i1 + 1] = p + D.D[ia, i1 + 2] = p - D.D[ia+1, i1] = p - D.D[ia+1, i1+2] = p + D.D[ia + 1, i1] = p + D.D[ia + 1, i1 + 2] = p - D.D[ia+2, i1] = p - D.D[ia+2, i1+1] = p + D.D[ia + 2, i1] = p + D.D[ia + 2, i1 + 1] = p self.func = func + return t() class TestDynamicalMatrix: - def test_objects(self, setup): assert len(setup.D.xyz) == 2 assert setup.g.no == len(setup.D) @@ -76,14 +84,14 @@ def test_ortho(self, setup): assert setup.D.orthogonal def test_set1(self, setup): - setup.D.D[0, 0] = 1. - assert setup.D[0, 0] == 1. - assert setup.D[1, 0] == 0. + setup.D.D[0, 0] = 1.0 + assert setup.D[0, 0] == 1.0 + assert setup.D[1, 0] == 0.0 setup.D.empty() def test_apply_newton(self, setup): setup.D.construct(setup.func) - assert setup.D[0, 0] == 1. + assert setup.D[0, 0] == 1.0 assert setup.D[1, 0] == 0.1 assert setup.D[0, 1] == 0.1 setup.D.apply_newton() @@ -102,21 +110,22 @@ def test_change_gauge(self, setup): D.construct(setup.func) em = D.eigenmode(k=(0.2, 0.2, 0.2)) em2 = em.copy() - em2.change_gauge('r') + em2.change_gauge("r") assert not np.allclose(em.mode, em2.mode) - em2.change_gauge('R') + em2.change_gauge("R") assert np.allclose(em.mode, em2.mode) - @pytest.mark.filterwarnings('ignore', category=np.ComplexWarning) + @pytest.mark.filterwarnings("ignore", category=np.ComplexWarning) def test_dos_pdos_velocity(self, setup): D = setup.D.copy() D.construct(setup.func) - E = np.linspace(0, .5, 10) + E = np.linspace(0, 0.5, 10) em = D.eigenmode() assert np.allclose(em.DOS(E), em.PDOS(E).sum(0)) def test_pickle(self, setup): import pickle as p + D = setup.D.copy() D.construct(setup.func) s = p.dumps(D) diff --git a/src/sisl/physics/tests/test_energy_density_matrix.py b/src/sisl/physics/tests/test_energy_density_matrix.py index 953524f1cc..7fcf6a9a8a 100644 --- a/src/sisl/physics/tests/test_energy_density_matrix.py +++ b/src/sisl/physics/tests/test_energy_density_matrix.py @@ -11,54 +11,63 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) C = Atom(Z=6, R=[bond * 1.01] * 3) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.E = EnergyDensityMatrix(self.g) self.ES = EnergyDensityMatrix(self.g, orthogonal=False) def func(E, ia, idxs, idxs_xyz): - idx = E.geometry.close(ia, R=(0.1, 1.44), atoms=idxs, atoms_xyz=idxs_xyz) + idx = E.geometry.close( + ia, R=(0.1, 1.44), atoms=idxs, atoms_xyz=idxs_xyz + ) ia = ia * 3 i0 = idx[0] * 3 i1 = idx[1] * 3 # on-site - p = 1. + p = 1.0 E.E[ia, i0] = p - E.E[ia+1, i0+1] = p - E.E[ia+2, i0+2] = p + E.E[ia + 1, i0 + 1] = p + E.E[ia + 2, i0 + 2] = p # nn p = 0.1 # on-site directions - E.E[ia, ia+1] = p - E.E[ia, ia+2] = p - E.E[ia+1, ia] = p - E.E[ia+1, ia+2] = p - E.E[ia+2, ia] = p - E.E[ia+2, ia+1] = p + E.E[ia, ia + 1] = p + E.E[ia, ia + 2] = p + E.E[ia + 1, ia] = p + E.E[ia + 1, ia + 2] = p + E.E[ia + 2, ia] = p + E.E[ia + 2, ia + 1] = p - E.E[ia, i1+1] = p - E.E[ia, i1+2] = p + E.E[ia, i1 + 1] = p + E.E[ia, i1 + 2] = p - E.E[ia+1, i1] = p - E.E[ia+1, i1+2] = p + E.E[ia + 1, i1] = p + E.E[ia + 1, i1 + 2] = p - E.E[ia+2, i1] = p - E.E[ia+2, i1+1] = p + E.E[ia + 2, i1] = p + E.E[ia + 2, i1 + 1] = p self.func = func + return t() @@ -66,7 +75,6 @@ def func(E, ia, idxs, idxs_xyz): @pytest.mark.density_matrix @pytest.mark.energydensity_matrix class TestEnergyDensityMatrix: - def test_objects(self, setup): assert len(setup.E.xyz) == 2 assert setup.g.no == len(setup.E) @@ -74,9 +82,9 @@ def test_objects(self, setup): def test_spin(self, setup): g = setup.g.copy() EnergyDensityMatrix(g) - EnergyDensityMatrix(g, spin=Spin('P')) - EnergyDensityMatrix(g, spin=Spin('NC')) - EnergyDensityMatrix(g, spin=Spin('SO')) + EnergyDensityMatrix(g, spin=Spin("P")) + EnergyDensityMatrix(g, spin=Spin("NC")) + EnergyDensityMatrix(g, spin=Spin("SO")) def test_dtype(self, setup): assert setup.E.dtype == np.float64 @@ -87,43 +95,44 @@ def test_ortho(self, setup): def test_mulliken(self, setup): E = setup.E.copy() E.construct(setup.func) - mulliken = E.mulliken('atom') + mulliken = E.mulliken("atom") assert mulliken.shape == (len(E.geometry),) - mulliken = E.mulliken('orbital') + mulliken = E.mulliken("orbital") assert mulliken.shape == (len(E),) def test_mulliken_values_orthogonal(self, setup): E = setup.E.copy() - E[0, 0] = 1. - E[1, 1] = 2. - E[1, 2] = 2. - mulliken = E.mulliken('orbital') - assert np.allclose(mulliken[:2], [1., 2.]) + E[0, 0] = 1.0 + E[1, 1] = 2.0 + E[1, 2] = 2.0 + mulliken = E.mulliken("orbital") + assert np.allclose(mulliken[:2], [1.0, 2.0]) assert mulliken.sum() == pytest.approx(3) - mulliken = E.mulliken('atom') + mulliken = E.mulliken("atom") assert mulliken[0] == pytest.approx(3) assert mulliken.sum() == pytest.approx(3) def test_mulliken_values_non_orthogonal(self, setup): E = setup.ES.copy() - E[0, 0] = (1., 1.) - E[1, 1] = (2., 1.) - E[1, 2] = (2., 0.5) - mulliken = E.mulliken('orbital') - assert np.allclose(mulliken[:2], [1., 3.]) - assert mulliken.sum() == pytest.approx(4.) - mulliken = E.mulliken('atom') + E[0, 0] = (1.0, 1.0) + E[1, 1] = (2.0, 1.0) + E[1, 2] = (2.0, 0.5) + mulliken = E.mulliken("orbital") + assert np.allclose(mulliken[:2], [1.0, 3.0]) + assert mulliken.sum() == pytest.approx(4.0) + mulliken = E.mulliken("atom") assert mulliken[0] == pytest.approx(4) assert mulliken.sum() == pytest.approx(4) def test_set1(self, setup): E = setup.E.copy() - E.E[0, 0] = 1. - assert E[0, 0] == 1. - assert E[1, 0] == 0. + E.E[0, 0] = 1.0 + assert E[0, 0] == 1.0 + assert E[1, 0] == 0.0 def test_pickle(self, setup): import pickle as p + E = setup.E.copy() E.construct(setup.func) s = p.dumps(E) diff --git a/src/sisl/physics/tests/test_feature.py b/src/sisl/physics/tests/test_feature.py index 128c421552..8f6fd4533b 100644 --- a/src/sisl/physics/tests/test_feature.py +++ b/src/sisl/physics/tests/test_feature.py @@ -11,9 +11,13 @@ def test_yield_manifolds_eigenvalues(): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice([10, 1, 5.], nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice([10, 1, 5.0], nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64) - H.construct([(0.1, 1.5), (1., 0.1)]) + H.construct([(0.1, 1.5), (1.0, 0.1)]) all_manifolds = [] for manifold in yield_manifolds(H.eigh()): diff --git a/src/sisl/physics/tests/test_hamiltonian.py b/src/sisl/physics/tests/test_hamiltonian.py index 3d6f794c2f..2d2542162e 100644 --- a/src/sisl/physics/tests/test_hamiltonian.py +++ b/src/sisl/physics/tests/test_hamiltonian.py @@ -26,38 +26,49 @@ ) from sisl.physics.electron import berry_phase, conductivity, spin_squared -pytestmark = [pytest.mark.physics, pytest.mark.hamiltonian, - pytest.mark.filterwarnings("ignore", category=SparseEfficiencyWarning)] +pytestmark = [ + pytest.mark.physics, + pytest.mark.hamiltonian, + pytest.mark.filterwarnings("ignore", category=SparseEfficiencyWarning), +] @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) C = Atom(Z=6, R=[bond * 1.01]) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.H = Hamiltonian(self.g) self.HS = Hamiltonian(self.g, orthogonal=False) C = Atom(Z=6, R=[bond * 1.01] * 2) - self.g2 = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g2 = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.H2 = Hamiltonian(self.g2) self.HS2 = Hamiltonian(self.g2, orthogonal=False) + return t() class TestHamiltonian: - def test_objects(self, setup): assert len(setup.H.xyz) == 2 assert setup.g.no == len(setup.H) @@ -82,45 +93,45 @@ def test_ortho(self, setup): assert setup.HS.non_orthogonal def test_set1(self, setup): - setup.H.H[0, 0] = 1. - assert setup.H[0, 0] == 1. - assert setup.H[1, 0] == 0. + setup.H.H[0, 0] = 1.0 + assert setup.H[0, 0] == 1.0 + assert setup.H[1, 0] == 0.0 setup.H.empty() - setup.HS.H[0, 0] = 1. - assert setup.HS.H[0, 0] == 1. - assert setup.HS.H[1, 0] == 0. - assert setup.HS.S[0, 0] == 0. - assert setup.HS.S[1, 0] == 0. - setup.HS.S[0, 0] = 1. - assert setup.HS.H[0, 0] == 1. - assert setup.HS.H[1, 0] == 0. - assert setup.HS.S[0, 0] == 1. - assert setup.HS.S[1, 0] == 0. + setup.HS.H[0, 0] = 1.0 + assert setup.HS.H[0, 0] == 1.0 + assert setup.HS.H[1, 0] == 0.0 + assert setup.HS.S[0, 0] == 0.0 + assert setup.HS.S[1, 0] == 0.0 + setup.HS.S[0, 0] = 1.0 + assert setup.HS.H[0, 0] == 1.0 + assert setup.HS.H[1, 0] == 0.0 + assert setup.HS.S[0, 0] == 1.0 + assert setup.HS.S[1, 0] == 0.0 # delete before creating the same content setup.HS.empty() # THIS IS A CHECK FOR BACK_WARD COMPATIBILITY! with warnings.catch_warnings(): - warnings.simplefilter('ignore') - setup.HS[0, 0] = 1., 1. - assert setup.HS.H[0, 0] == 1. - assert setup.HS.S[0, 0] == 1. + warnings.simplefilter("ignore") + setup.HS[0, 0] = 1.0, 1.0 + assert setup.HS.H[0, 0] == 1.0 + assert setup.HS.S[0, 0] == 1.0 setup.HS.empty() def test_set2(self, setup): - setup.H.construct([(0.1, 1.5), (1., 0.1)]) - assert setup.H[0, 0] == 1. + setup.H.construct([(0.1, 1.5), (1.0, 0.1)]) + assert setup.H[0, 0] == 1.0 assert setup.H[1, 0] == 0.1 assert setup.H[0, 1] == 0.1 setup.H.empty() def test_set3(self, setup): - setup.HS.construct([(0.1, 1.5), ((1., 2.), (0.1, 0.2))]) - assert setup.HS.H[0, 0] == 1. - assert setup.HS.S[0, 0] == 2. - assert setup.HS.H[1, 1] == 1. - assert setup.HS.S[1, 1] == 2. + setup.HS.construct([(0.1, 1.5), ((1.0, 2.0), (0.1, 0.2))]) + assert setup.HS.H[0, 0] == 1.0 + assert setup.HS.S[0, 0] == 2.0 + assert setup.HS.H[1, 1] == 1.0 + assert setup.HS.S[1, 1] == 2.0 assert setup.HS.H[1, 0] == 0.1 assert setup.HS.H[0, 1] == 0.1 assert setup.HS.S[1, 0] == 0.2 @@ -132,10 +143,10 @@ def test_set4(self, setup): for ia in setup.H.geometry: # Find atoms close to 'ia' idx = setup.H.geometry.close(ia, R=(0.1, 1.5)) - setup.H[ia, idx[0]] = 1. + setup.H[ia, idx[0]] = 1.0 setup.H[ia, idx[1]] = 0.1 - assert setup.H.H[0, 0] == 1. - assert setup.H.H[1, 1] == 1. + assert setup.H.H[0, 0] == 1.0 + assert setup.H.H[1, 1] == 1.0 assert setup.H.H[1, 0] == 0.1 assert setup.H.H[0, 1] == 0.1 assert setup.H.nnz == len(setup.H) * 4 @@ -146,9 +157,9 @@ def test_set5(self, setup): # Test of HUGE construct g = setup.g.tile(10, 0).tile(10, 1).tile(10, 2) H = Hamiltonian(g) - H.construct([(0.1, 1.5), (1., 0.1)]) - assert H.H[0, 0] == 1. - assert H.H[1, 1] == 1. + H.construct([(0.1, 1.5), (1.0, 0.1)]) + assert H.H[0, 0] == 1.0 + assert H.H[1, 1] == 1.0 assert H.H[1, 0] == 0.1 assert H.H[0, 1] == 0.1 # This is graphene @@ -158,7 +169,7 @@ def test_set5(self, setup): del H def test_iter1(self, setup): - setup.HS.construct([(0.1, 1.5), ((1., 2.), (0.1, 0.2))]) + setup.HS.construct([(0.1, 1.5), ((1.0, 2.0), (0.1, 0.2))]) nnz = 0 for io, jo in setup.HS: nnz = nnz + 1 @@ -171,7 +182,7 @@ def test_iter1(self, setup): setup.HS.empty() def test_iter2(self, setup): - setup.HS.H[0, 0] = 1. + setup.HS.H[0, 0] = 1.0 nnz = 0 for io, jo in setup.HS: nnz = nnz + 1 @@ -179,10 +190,10 @@ def test_iter2(self, setup): assert nnz == 1 setup.HS.empty() - @pytest.mark.filterwarnings('ignore', category=np.ComplexWarning) + @pytest.mark.filterwarnings("ignore", category=np.ComplexWarning) def test_Hk1(self, setup): H = setup.HS.copy() - H.construct([(0.1, 1.5), ((1., 2.), (0.1, 0.2))]) + H.construct([(0.1, 1.5), ((1.0, 2.0), (0.1, 0.2))]) h = H.copy() assert h.Hk().dtype == np.float64 assert h.Sk().dtype == np.float64 @@ -197,7 +208,7 @@ def test_Hk1(self, setup): def test_Hk2(self, setup): H = setup.HS.copy() - H.construct([(0.1, 1.5), ((1., 2.), (0.1, 0.2))]) + H.construct([(0.1, 1.5), ((1.0, 2.0), (0.1, 0.2))]) h = H.copy() Hk = h.Hk(k=[0.15, 0.15, 0.15]) assert Hk.dtype == np.complex128 @@ -208,7 +219,7 @@ def test_Hk2(self, setup): @pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) def test_Hk5(self, setup, dtype): H = setup.H.copy() - H.construct([(0.1, 1.5), (1., 0.1)]) + H.construct([(0.1, 1.5), (1.0, 0.1)]) Hk = H.Hk(k=[0.15, 0.15, 0.15], dtype=dtype) assert Hk.dtype == dtype Sk = H.Sk(k=[0.15, 0.15, 0.15]) @@ -221,21 +232,27 @@ def test_Hk5(self, setup, dtype): @pytest.mark.parametrize("k", [[0, 0, 0], [0.15, 0.15, 0.15]]) def test_Hk_format(self, setup, k): H = setup.HS.copy() - H.construct([(0.1, 1.5), ((1., 2.), (0.1, 0.2))]) - csr = H.Hk(k, format='csr').toarray() - mat = H.Hk(k, format='matrix') - arr = H.Hk(k, format='array') - coo = H.Hk(k, format='coo').toarray() + H.construct([(0.1, 1.5), ((1.0, 2.0), (0.1, 0.2))]) + csr = H.Hk(k, format="csr").toarray() + mat = H.Hk(k, format="matrix") + arr = H.Hk(k, format="array") + coo = H.Hk(k, format="coo").toarray() assert np.allclose(csr, mat) assert np.allclose(csr, arr) assert np.allclose(csr, coo) @pytest.mark.parametrize("orthogonal", [True, False]) @pytest.mark.parametrize("gauge", ["R", "r"]) - @pytest.mark.parametrize("spin", ["unpolarized", "polarized", "non-collinear", "spin-orbit"]) + @pytest.mark.parametrize( + "spin", ["unpolarized", "polarized", "non-collinear", "spin-orbit"] + ) @pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) def test_format_sc(self, orthogonal, gauge, spin, dtype): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice([10, 1, 5.], nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice([10, 1, 5.0], nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, orthogonal=orthogonal, spin=Spin(spin)) nd = H._csr._D.shape[-1] # this will correctly account for the double size for NC/SOC @@ -243,7 +260,7 @@ def test_format_sc(self, orthogonal, gauge, spin, dtype): no_s = H.geometry.no_s for ia in g: idx = g.close(ia, R=(0.1, 1.01))[1] - H[ia, ia] = 1. + H[ia, ia] = 1.0 H[ia, idx] = np.random.rand(nd) if dtype == np.complex64: atol = 1e-6 @@ -259,24 +276,24 @@ def test_format_sc(self, orthogonal, gauge, spin, dtype): for k in [[0, 0, 0], [0.15, 0.1, 0.05]]: for attr, kwargs in [("Hk", {"gauge": gauge}), ("Sk", {})]: Mk = getattr(H, attr) - csr = Mk(k, format='csr', **kwargs, dtype=dtype) - sc_csr1 = Mk(k, format='sc:csr', **kwargs, dtype=dtype) - sc_csr2 = Mk(k, format='sc', **kwargs, dtype=dtype) - sc_mat = Mk(k, format='sc:array', **kwargs, dtype=dtype) + csr = Mk(k, format="csr", **kwargs, dtype=dtype) + sc_csr1 = Mk(k, format="sc:csr", **kwargs, dtype=dtype) + sc_csr2 = Mk(k, format="sc", **kwargs, dtype=dtype) + sc_mat = Mk(k, format="sc:array", **kwargs, dtype=dtype) mat = sc_mat.reshape(no, n_s, no).sum(1) assert sc_mat.shape == sc_csr1.shape assert allclose(csr.toarray(), mat) assert allclose(sc_csr1.toarray(), sc_csr2.toarray()) for isc in range(n_s): - csr -= sc_csr1[:, isc * no: (isc + 1) * no] - assert allclose(csr.toarray(), 0.) + csr -= sc_csr1[:, isc * no : (isc + 1) * no] + assert allclose(csr.toarray(), 0.0) def test_construct_raise_default(self, setup): # Test that construct fails with more than one # orbital with pytest.raises(ValueError): - setup.H2.construct([(0.1, 1.5), (1., 0.1)]) + setup.H2.construct([(0.1, 1.5), (1.0, 0.1)]) def test_getitem1(self, setup): H = setup.H @@ -306,7 +323,7 @@ def test_delitem1(self, setup): H.empty() def test_fromsp1(self, setup): - setup.H.construct([(0.1, 1.5), (1., 0.1)]) + setup.H.construct([(0.1, 1.5), (1.0, 0.1)]) csr = setup.H.tocsr(0) H = Hamiltonian.fromsp(setup.H.geometry, csr) assert H.spsame(setup.H) @@ -314,14 +331,14 @@ def test_fromsp1(self, setup): def test_fromsp2(self, setup): H = setup.H.copy() - H.construct([(0.1, 1.5), (1., 0.1)]) + H.construct([(0.1, 1.5), (1.0, 0.1)]) csr = H.tocsr(0) with pytest.raises(ValueError): Hamiltonian.fromsp(setup.H.geometry.tile(2, 0), csr) def test_fromsp3(self, setup): H = setup.HS.copy() - H.construct([(0.1, 1.5), ([1., 1.], [0.1, 0])]) + H.construct([(0.1, 1.5), ([1.0, 1.0], [0.1, 0])]) h = Hamiltonian.fromsp(H.geometry.copy(), H.tocsr(0), H.tocsr(1)) assert H.spsame(h) @@ -329,13 +346,13 @@ def test_op1(self, setup): g = Geometry([[i, 0, 0] for i in range(100)], Atom(6, R=1.01), lattice=[100]) H = Hamiltonian(g, dtype=np.int32) for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) H[0, j] = i # i+ H += 1 for jj in j: - assert H[0, jj] == i+1 + assert H[0, jj] == i + 1 assert H[1, jj] == 0 # i- @@ -347,7 +364,7 @@ def test_op1(self, setup): # i* H *= 2 for jj in j: - assert H[0, jj] == i*2 + assert H[0, jj] == i * 2 assert H[1, jj] == 0 # // @@ -366,34 +383,34 @@ def test_op2(self, setup): g = Geometry([[i, 0, 0] for i in range(100)], Atom(6, R=1.01), lattice=[100]) H = Hamiltonian(g, dtype=np.int32) for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) H[0, j] = i # + s = H + 1 for jj in j: - assert s[0, jj] == i+1 + assert s[0, jj] == i + 1 assert H[0, jj] == i assert s[1, jj] == 0 # - s = H - 1 for jj in j: - assert s[0, jj] == i-1 + assert s[0, jj] == i - 1 assert H[0, jj] == i assert s[1, jj] == 0 # - s = 1 - H for jj in j: - assert s[0, jj] == 1-i + assert s[0, jj] == 1 - i assert H[0, jj] == i assert s[1, jj] == 0 # * s = H * 2 for jj in j: - assert s[0, jj] == i*2 + assert s[0, jj] == i * 2 assert H[0, jj] == i assert s[1, jj] == 0 @@ -405,14 +422,14 @@ def test_op2(self, setup): assert s[1, jj] == 0 # ** - s = H ** 2 + s = H**2 for jj in j: assert s[0, jj] == i**2 assert H[0, jj] == i assert s[1, jj] == 0 # ** (r) - s = 2 ** H + s = 2**H for jj in j: assert s[0, jj], 2 ** H[0 == jj] assert H[0, jj] == i @@ -426,39 +443,39 @@ def test_op3(self, setup): # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) H[0, j] = i - for op in ['add', 'sub', 'mul', 'pow']: - func = getattr(H, f'__{op}__') + for op in ["add", "sub", "mul", "pow"]: + func = getattr(H, f"__{op}__") h = func(1) assert h.dtype == np.int32 - h = func(1.) + h = func(1.0) assert h.dtype == np.float64 - if op != 'pow': - h = func(1.j) + if op != "pow": + h = func(1.0j) assert h.dtype == np.complex128 H = H.copy(dtype=np.float64) - for op in ['add', 'sub', 'mul', 'pow']: - func = getattr(H, f'__{op}__') + for op in ["add", "sub", "mul", "pow"]: + func = getattr(H, f"__{op}__") h = func(1) assert h.dtype == np.float64 - h = func(1.) + h = func(1.0) assert h.dtype == np.float64 - if op != 'pow': - h = func(1.j) + if op != "pow": + h = func(1.0j) assert h.dtype == np.complex128 H = H.copy(dtype=np.complex128) - for op in ['add', 'sub', 'mul', 'pow']: - func = getattr(H, f'__{op}__') + for op in ["add", "sub", "mul", "pow"]: + func = getattr(H, f"__{op}__") h = func(1) assert h.dtype == np.complex128 - h = func(1.) + h = func(1.0) assert h.dtype == np.complex128 - if op != 'pow': - h = func(1.j) + if op != "pow": + h = func(1.0j) assert h.dtype == np.complex128 def test_op4(self, setup): @@ -466,41 +483,41 @@ def test_op4(self, setup): H = Hamiltonian(g, dtype=np.int32) # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) H[0, j] = i h = 1 + H assert h.dtype == np.int32 - h = 1. + H + h = 1.0 + H assert h.dtype == np.float64 - h = 1.j + H + h = 1.0j + H assert h.dtype == np.complex128 h = 1 - H assert h.dtype == np.int32 - h = 1. - H + h = 1.0 - H assert h.dtype == np.float64 - h = 1.j - H + h = 1.0j - H assert h.dtype == np.complex128 h = 1 * H assert h.dtype == np.int32 - h = 1. * H + h = 1.0 * H assert h.dtype == np.float64 - h = 1.j * H + h = 1.0j * H assert h.dtype == np.complex128 - h = 1 ** H + h = 1**H assert h.dtype == np.int32 - h = 1. ** H + h = 1.0**H assert h.dtype == np.float64 - h = 1.j ** H + h = 1.0j**H assert h.dtype == np.complex128 def test_untile1(self, setup): # Test of eigenvalues using a cut # Hamiltonian - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] # Create reference Hg = Hamiltonian(setup.g) @@ -519,7 +536,7 @@ def test_untile1(self, setup): def test_untile2(self, setup): # Test of eigenvalues using a cut # Hamiltonian - R, param = [0.1, 1.5], [(1., 1.), (0.1, 0.1)] + R, param = [0.1, 1.5], [(1.0, 1.0), (0.1, 0.1)] # Create reference Hg = Hamiltonian(setup.g, orthogonal=False) @@ -537,7 +554,7 @@ def test_untile2(self, setup): def test_eigh_vs_eig(self, setup): # Test of eigenvalues - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param), eta=True) @@ -549,13 +566,15 @@ def test_eigh_vs_eig(self, setup): eig1 = H.eigh([0.01] * 3, dtype=np.complex64) eig2 = np.sort(H.eig([0.01] * 3, dtype=np.complex64).real) - eig3 = np.sort(H.eig([0.01] * 3, eigvals_only=False, dtype=np.complex64)[0].real) + eig3 = np.sort( + H.eig([0.01] * 3, eigvals_only=False, dtype=np.complex64)[0].real + ) assert np.allclose(eig1, eig2, atol=1e-5) assert np.allclose(eig1, eig3, atol=1e-5) def test_eig1(self, setup): # Test of eigenvalues - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param), eta=True) @@ -568,13 +587,13 @@ def test_eig1(self, setup): def test_eig2(self, setup): # Test of eigenvalues HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((1., 1.), (0.1, 0.1))]) + HS.construct([(0.1, 1.5), ((1.0, 1.0), (0.1, 0.1))]) eig1 = HS.eigh(dtype=np.complex64) assert np.allclose(eig1, HS.eigh(dtype=np.complex128)) setup.HS.empty() def test_eig3(self, setup): - setup.HS.construct([(0.1, 1.5), ((1., 1.), (0.1, 0.1))]) + setup.HS.construct([(0.1, 1.5), ((1.0, 1.0), (0.1, 0.1))]) BS = BandStructure(setup.HS, [[0, 0, 0], [0.5, 0.5, 0]], 10) eigs = BS.apply.array.eigh() assert len(BS) == eigs.shape[0] @@ -587,9 +606,9 @@ def test_eig3(self, setup): def test_eig4(self, setup): # Test of eigenvalues vs eigenstate class HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((1., 1.), (0.1, 0.1))]) + HS.construct([(0.1, 1.5), ((1.0, 1.0), (0.1, 0.1))]) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): e, v = HS.eigh(k, eigvals_only=False) es = HS.eigenstate(k) assert np.allclose(e, es.eig) @@ -613,13 +632,14 @@ def test_eig4(self, setup): assert np.allclose(eig1, eig4, atol=1e-5) assert np.allclose(eig1, eig5, atol=1e-5) - assert es.inner(matrix=HS.Hk([0.1] * 3), ket=HS.eigenstate([0.3] * 3), - diag=False).shape == (len(es), len(es)) + assert es.inner( + matrix=HS.Hk([0.1] * 3), ket=HS.eigenstate([0.3] * 3), diag=False + ).shape == (len(es), len(es)) @pytest.mark.filterwarnings("ignore", message="*uses an overlap matrix that") def test_inner(self, setup): HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((2., 1.), (3., 0.))]) + HS.construct([(0.1, 1.5), ((2.0, 1.0), (3.0, 0.0))]) HS = HS.tile(2, 0).tile(2, 1) es1 = HS.eigenstate([0.1] * 3) @@ -640,25 +660,25 @@ def test_inner(self, setup): def test_gauge_eig(self, setup): # Test of eigenvalues - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) k = [0.1] * 3 - es1 = H.eigenstate(k, gauge='R') - es2 = H.eigenstate(k, gauge='r') + es1 = H.eigenstate(k, gauge="R") + es2 = H.eigenstate(k, gauge="r") assert np.allclose(es1.eig, es2.eig) assert not np.allclose(es1.state, es2.state) - es1 = H.eigenstate(k, gauge='R', dtype=np.complex64) - es2 = H.eigenstate(k, gauge='r', dtype=np.complex64) + es1 = H.eigenstate(k, gauge="R", dtype=np.complex64) + es2 = H.eigenstate(k, gauge="r", dtype=np.complex64) assert np.allclose(es1.eig, es2.eig) assert not np.allclose(es1.state, es2.state) def test_eigenstate_ipr(self, setup): # Test of eigenvalues - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) @@ -670,7 +690,7 @@ def test_eigenstate_ipr(self, setup): def test_eigenstate_tile(self, setup): # Test of eigenvalues - R, param = [0.1, 1.5], [0., 2.7] + R, param = [0.1, 1.5], [0.0, 2.7] H1 = setup.H.copy() H1.construct((R, param)) H2 = H1.tile(2, 1) @@ -678,11 +698,11 @@ def test_eigenstate_tile(self, setup): k = [0] * 3 # we must select a k that does not fold on # itself which then creates degenerate states - for k1 in [0.5, 1/3]: + for k1 in [0.5, 1 / 3]: k[1] = k1 es1 = H1.eigenstate(k) es1_2 = es1.tile(2, 1, normalize=True) - es2 = H2.eigenstate(es1_2.info['k']) + es2 = H2.eigenstate(es1_2.info["k"]) # we need to check that these are somewhat the same out = es1_2.inner(es2, diag=False) @@ -691,7 +711,7 @@ def test_eigenstate_tile(self, setup): def test_eigenstate_tile_offset(self, setup): # Test of eigenvalues - R, param = [0.1, 1.5], [0., 2.7] + R, param = [0.1, 1.5], [0.0, 2.7] H1 = setup.H.copy() H1.construct((R, param)) H2 = H1.tile(2, 1) @@ -699,11 +719,11 @@ def test_eigenstate_tile_offset(self, setup): k = [0] * 3 # we must select a k that does not fold on # itself which then creates degenerate states - for k1 in [0.5, 1/3]: + for k1 in [0.5, 1 / 3]: k[1] = k1 es1 = H1.eigenstate(k) es1_2 = es1.tile(2, 1, normalize=True, offset=1) - es2 = H2.eigenstate(es1_2.info['k']).translate([0, 1, 0]) + es2 = H2.eigenstate(es1_2.info["k"]).translate([0, 1, 0]) # we need to check that these are somewhat the same out = es1_2.inner(es2, diag=False) @@ -712,25 +732,23 @@ def test_eigenstate_tile_offset(self, setup): def test_eigenstate_translate(self, setup): # Test of eigenvalues - R, param = [0.1, 1.5], [0., 2.7] + R, param = [0.1, 1.5], [0.0, 2.7] H = setup.H.copy() H.construct((R, param)) k = [0] * 3 # we must select a k that does not fold on # itself which then creates degenerate states - for k1 in [0.5, 1/3]: + for k1 in [0.5, 1 / 3]: k[1] = k1 es1 = H.eigenstate(k) es1_2 = es1.tile(2, 1) es2 = es1.translate([0, 1, 0]) - assert np.allclose(es1_2.state[:, :len(es1)], - es1.state) - assert np.allclose(es1_2.state[:, len(es1):], - es2.state) + assert np.allclose(es1_2.state[:, : len(es1)], es1.state) + assert np.allclose(es1_2.state[:, len(es1) :], es2.state) def test_gauge_velocity(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) @@ -742,16 +760,16 @@ def test_gauge_velocity(self, setup): # This test is the reason why a default degenerate=1e-5 is used # since the gauge='R' yields the correct *decoupled* states # where as gauge='r' mixes them in a bad way. - es1 = H.eigenstate(k, gauge='R') - es2 = H.eigenstate(k, gauge='r') + es1 = H.eigenstate(k, gauge="R") + es2 = H.eigenstate(k, gauge="r") assert np.allclose(es1.velocity(), es2.velocity()) assert np.allclose(es1.velocity(), es2.velocity(degenerate_dir=(1, 1, 0))) - es2.change_gauge('R') + es2.change_gauge("R") assert np.allclose(es1.velocity(), es2.velocity()) - es2.change_gauge('r') - es1.change_gauge('r') + es2.change_gauge("r") + es1.change_gauge("r") v1 = es1.velocity() v2 = es2.velocity() assert np.allclose(v1, v2) @@ -763,7 +781,7 @@ def test_gauge_velocity(self, setup): assert np.allclose(np.diagonal(vv2, axis1=1, axis2=2), v1) def test_derivative_orthogonal(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) @@ -775,7 +793,7 @@ def test_derivative_orthogonal(self, setup): assert np.allclose(v1, v) def test_derivative_non_orthogonal(self, setup): - R, param = [0.1, 1.5], [(1., 1.), (0.1, 0.1)] + R, param = [0.1, 1.5], [(1.0, 1.0), (0.1, 0.1)] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g, orthogonal=False) H.construct((R, param)) @@ -787,45 +805,47 @@ def test_derivative_non_orthogonal(self, setup): assert np.allclose(v1, v) def test_berry_phase(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) - bz = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1/3] * 3) + bz = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1 / 3] * 3) berry_phase(bz) berry_phase(bz, sub=0) - berry_phase(bz, eigvals=True, sub=0, method='berry:svd') + berry_phase(bz, eigvals=True, sub=0, method="berry:svd") def test_berry_phase_fail_sc(self, setup): g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) - bz = BandStructure.param_circle(H.geometry.lattice, 20, 0.01, [0, 0, 1], [1/3] * 3) + bz = BandStructure.param_circle( + H.geometry.lattice, 20, 0.01, [0, 0, 1], [1 / 3] * 3 + ) with pytest.raises(SislError): berry_phase(bz) def test_berry_phase_loop(self, setup): g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) - bz1 = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1/3] * 3) - bz2 = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1/3] * 3, loop=True) + bz1 = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1 / 3] * 3) + bz2 = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1 / 3] * 3, loop=True) assert np.allclose(berry_phase(bz1), berry_phase(bz2)) def test_berry_phase_non_orthogonal(self, setup): - R, param = [0.1, 1.5], [(1., 1.), (0.1, 0.1)] + R, param = [0.1, 1.5], [(1.0, 1.0), (0.1, 0.1)] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g, orthogonal=False) H.construct((R, param)) - bz = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1/3] * 3) + bz = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1 / 3] * 3) berry_phase(bz) def test_berry_phase_orthogonal_spin_down(self, setup): - R, param = [0.1, 1.5], [(1., 1.), (0.1, 0.2)] + R, param = [0.1, 1.5], [(1.0, 1.0), (0.1, 0.2)] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g, spin=Spin.POLARIZED) H.construct((R, param)) - bz = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1/3] * 3) + bz = BandStructure.param_circle(H, 20, 0.01, [0, 0, 1], [1 / 3] * 3) bp1 = berry_phase(bz) bp2 = berry_phase(bz, eigenstate_kwargs={"spin": 1}) assert bp1 != bp2 @@ -836,14 +856,16 @@ def test_berry_phase_zak_x_topological(self): g = Geometry([[0, 0, 0], [1.2, 0, 0]], Atom(1, 1.001), lattice=[2, 10, 10]) g.set_nsc([3, 1, 1]) H = Hamiltonian(g) - H.construct([(0.1, 1.0, 1.5), (0, 1., 0.5)]) + H.construct([(0.1, 1.0, 1.5), (0, 1.0, 0.5)]) + # Contour def func(parent, N, i): - return [i/N, 0, 0] + return [i / N, 0, 0] + bz = BrillouinZone.parametrize(H, func, 101) - assert np.allclose(np.abs(berry_phase(bz, sub=0, method='zak')), np.pi) + assert np.allclose(np.abs(berry_phase(bz, sub=0, method="zak")), np.pi) # Just to do the other branch - berry_phase(bz, method='zak') + berry_phase(bz, method="zak") def test_berry_phase_zak_x_topological_non_orthogonal(self): # SSH model, topological cell @@ -851,14 +873,16 @@ def test_berry_phase_zak_x_topological_non_orthogonal(self): g = Geometry([[0, 0, 0], [1.2, 0, 0]], Atom(1, 1.001), lattice=[2, 10, 10]) g.set_nsc([3, 1, 1]) H = Hamiltonian(g, orthogonal=False) - H.construct([(0.1, 1.0, 1.5), ((0, 1), (1., 0.25), (0.5, 0.1))]) + H.construct([(0.1, 1.0, 1.5), ((0, 1), (1.0, 0.25), (0.5, 0.1))]) + # Contour def func(parent, N, i): - return [i/N, 0, 0] + return [i / N, 0, 0] + bz = BrillouinZone.parametrize(H, func, 101) - assert np.allclose(np.abs(berry_phase(bz, sub=0, method='zak')), np.pi) + assert np.allclose(np.abs(berry_phase(bz, sub=0, method="zak")), np.pi) # Just to do the other branch - berry_phase(bz, method='zak') + berry_phase(bz, method="zak") def test_berry_phase_zak_x_trivial(self): # SSH model, trivial cell @@ -866,67 +890,75 @@ def test_berry_phase_zak_x_trivial(self): g = Geometry([[0, 0, 0], [1.2, 0, 0]], Atom(1, 1.001), lattice=[2, 10, 10]) g.set_nsc([3, 1, 1]) H = Hamiltonian(g) - H.construct([(0.1, 1.0, 1.5), (0, 0.5, 1.)]) + H.construct([(0.1, 1.0, 1.5), (0, 0.5, 1.0)]) + # Contour def func(parent, N, i): - return [i/N, 0, 0] + return [i / N, 0, 0] + bz = BrillouinZone.parametrize(H, func, 101) - assert np.allclose(np.abs(berry_phase(bz, sub=0, method='zak')), 0.) + assert np.allclose(np.abs(berry_phase(bz, sub=0, method="zak")), 0.0) # Just to do the other branch - berry_phase(bz, method='zak') + berry_phase(bz, method="zak") def test_berry_phase_zak_y(self): # SSH model, topological cell - g = Geometry([[0, -.6, 0], [0, 0.6, 0]], Atom(1, 1.001), lattice=[10, 2, 10]) + g = Geometry([[0, -0.6, 0], [0, 0.6, 0]], Atom(1, 1.001), lattice=[10, 2, 10]) g.set_nsc([1, 3, 1]) H = Hamiltonian(g) - H.construct([(0.1, 1.0, 1.5), (0, 1., 0.5)]) + H.construct([(0.1, 1.0, 1.5), (0, 1.0, 0.5)]) + # Contour def func(parent, N, i): - return [0, i/N, 0] + return [0, i / N, 0] + bz = BrillouinZone.parametrize(H, func, 101) - assert np.allclose(np.abs(berry_phase(bz, sub=0, method='zak')), np.pi) + assert np.allclose(np.abs(berry_phase(bz, sub=0, method="zak")), np.pi) # Just to do the other branch - berry_phase(bz, method='zak') + berry_phase(bz, method="zak") def test_berry_phase_zak_offset(self): # SSH model, topological cell - g = Geometry([[0., 0, 0], [1.2, 0, 0]], Atom(1, 1.001), lattice=[2, 10, 10]) + g = Geometry([[0.0, 0, 0], [1.2, 0, 0]], Atom(1, 1.001), lattice=[2, 10, 10]) g.set_nsc([3, 1, 1]) H = Hamiltonian(g) - H.construct([(0.1, 1.0, 1.5), (0, 1., 0.5)]) + H.construct([(0.1, 1.0, 1.5), (0, 1.0, 0.5)]) + # Contour def func(parent, N, i): - return [i/N, 0, 0] + return [i / N, 0, 0] + bz = BrillouinZone.parametrize(H, func, 101) - zak = berry_phase(bz, sub=0, method='zak') + zak = berry_phase(bz, sub=0, method="zak") assert np.allclose(np.abs(zak), np.pi) def test_berry_phase_method_fail(self): # wrong method keyword - g = Geometry([[-.6, 0, 0], [0.6, 0, 0]], Atom(1, 1.001), lattice=[2, 10, 10]) + g = Geometry([[-0.6, 0, 0], [0.6, 0, 0]], Atom(1, 1.001), lattice=[2, 10, 10]) g.set_nsc([3, 1, 1]) H = Hamiltonian(g) + def func(parent, N, i): - return [0, i/N, 0] + return [0, i / N, 0] + bz = BrillouinZone.parametrize(H, func, 101) with pytest.raises(ValueError): - berry_phase(bz, method='unknown') + berry_phase(bz, method="unknown") def test_berry_curvature(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) k = [0.1] * 3 - ie1 = H.eigenstate(k, gauge='R').berry_curvature() - ie2 = H.eigenstate(k, gauge='r').berry_curvature(degenerate_dir=(1, 1, 0)) + ie1 = H.eigenstate(k, gauge="R").berry_curvature() + ie2 = H.eigenstate(k, gauge="r").berry_curvature(degenerate_dir=(1, 1, 0)) assert np.allclose(ie1, ie2) - @pytest.mark.filterwarnings('ignore', category=np.ComplexWarning) + @pytest.mark.filterwarnings("ignore", category=np.ComplexWarning) def test_conductivity(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) @@ -934,9 +966,9 @@ def test_conductivity(self, setup): mp = MonkhorstPack(H, [11, 11, 1]) cond = conductivity(mp) - @pytest.mark.filterwarnings('ignore', category=np.ComplexWarning) + @pytest.mark.filterwarnings("ignore", category=np.ComplexWarning) def test_conductivity_spin(self, setup): - R, param = [0.1, 1.5], [[1., 2.], [0.1, 0.2]] + R, param = [0.1, 1.5], [[1.0, 2.0], [0.1, 0.2]] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g, spin=Spin.POLARIZED) H.construct((R, param)) @@ -947,59 +979,59 @@ def test_conductivity_spin(self, setup): @pytest.mark.xfail(reason="Gauges make different decouplings") def test_gauge_eff(self, setup): # it is not fully clear to me why they are different - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] g = setup.g.tile(2, 0).tile(2, 1).tile(2, 2) H = Hamiltonian(g) H.construct((R, param)) k = [0.1] * 3 - ie1 = H.eigenstate(k, gauge='R').effective_mass() - ie2 = H.eigenstate(k, gauge='r').effective_mass() + ie1 = H.eigenstate(k, gauge="R").effective_mass() + ie2 = H.eigenstate(k, gauge="r").effective_mass() assert np.allclose(abs(ie1), abs(ie2)) def test_eigenstate_polarized_orthogonal_sk(self, setup): - R, param = [0.1, 1.5], [1., [0.1, 0.1]] - H = Hamiltonian(setup.g, spin='P') + R, param = [0.1, 1.5], [1.0, [0.1, 0.1]] + H = Hamiltonian(setup.g, spin="P") H.construct((R, param)) k = [0.1] * 3 - ie1 = H.eigenstate(k, spin=0, format='array').Sk() - ie2 = H.eigenstate(k, spin=1, format='array').Sk() - assert np.allclose(ie1.dot(1.), 1.) - assert np.allclose(ie2.dot(1.), 1.) + ie1 = H.eigenstate(k, spin=0, format="array").Sk() + ie2 = H.eigenstate(k, spin=1, format="array").Sk() + assert np.allclose(ie1.dot(1.0), 1.0) + assert np.allclose(ie2.dot(1.0), 1.0) def test_eigenstate_polarized_non_ortogonal_sk(self, setup): - R, param = [0.1, 1.5], [[1., 1., 1.], [0.1, 0.1, 0.05]] - H = Hamiltonian(setup.g, spin='P', orthogonal=False) + R, param = [0.1, 1.5], [[1.0, 1.0, 1.0], [0.1, 0.1, 0.05]] + H = Hamiltonian(setup.g, spin="P", orthogonal=False) H.construct((R, param)) k = [0.1] * 3 - ie1 = H.eigenstate(k, spin=0, format='array').Sk() - ie2 = H.eigenstate(k, spin=1, format='array').Sk() + ie1 = H.eigenstate(k, spin=0, format="array").Sk() + ie2 = H.eigenstate(k, spin=1, format="array").Sk() assert np.allclose(ie1, ie2) def test_change_gauge(self, setup): # Test of eigenvalues vs eigenstate class HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((1., 1.), (0.1, 0.1))]) + HS.construct([(0.1, 1.5), ((1.0, 1.0), (0.1, 0.1))]) es = HS.eigenstate() es2 = es.copy() - es2.change_gauge('r') + es2.change_gauge("r") assert np.allclose(es2.state, es.state) es = HS.eigenstate(k=(0.2, 0.2, 0.2)) es2 = es.copy() - es2.change_gauge('r') + es2.change_gauge("r") assert not np.allclose(es2.state, es.state) - es2.change_gauge('R') + es2.change_gauge("R") assert np.allclose(es2.state, es.state) def test_expectation_value(self, setup): H = setup.H.copy() - H.construct([(0.1, 1.5), ((1., 1.))]) + H.construct([(0.1, 1.5), ((1.0, 1.0))]) D = np.ones(len(H)) I = np.identity(len(H)) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = H.eigenstate(k) d = es.inner(matrix=D) @@ -1014,20 +1046,20 @@ def test_expectation_value(self, setup): def test_velocity_orthogonal(self, setup): H = setup.H.copy() - H.construct([(0.1, 1.5), ((1., 1.))]) + H.construct([(0.1, 1.5), ((1.0, 1.0))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = H.eigenstate(k) v = es.velocity() vsub = es.sub([0]).velocity()[:, 0] assert np.allclose(v[:, 0], vsub) - @pytest.mark.filterwarnings('ignore', category=np.ComplexWarning) + @pytest.mark.filterwarnings("ignore", category=np.ComplexWarning) def test_velocity_nonorthogonal(self, setup): HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((1., 1.), (0.1, 0.1))]) + HS.construct([(0.1, 1.5), ((1.0, 1.0), (0.1, 0.1))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = HS.eigenstate(k) v = es.velocity() vsub = es.sub([0]).velocity() @@ -1035,20 +1067,20 @@ def test_velocity_nonorthogonal(self, setup): def test_velocity_matrix_orthogonal(self, setup): H = setup.H.copy() - H.construct([(0.1, 1.5), ((1., 1.))]) + H.construct([(0.1, 1.5), ((1.0, 1.0))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = H.eigenstate(k) v = es.velocity(matrix=True) vsub = es.sub([0, 1]).velocity(matrix=True) assert np.allclose(v[:, :2, :2], vsub) - @pytest.mark.filterwarnings('ignore', category=np.ComplexWarning) + @pytest.mark.filterwarnings("ignore", category=np.ComplexWarning) def test_velocity_matrix_nonorthogonal(self, setup): HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((1., 1.), (0.1, 0.1))]) + HS.construct([(0.1, 1.5), ((1.0, 1.0), (0.1, 0.1))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = HS.eigenstate(k) v = es.velocity(matrix=True) vsub = es.sub([0, 1]).velocity(matrix=True) @@ -1056,24 +1088,24 @@ def test_velocity_matrix_nonorthogonal(self, setup): def test_dos1(self, setup): HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((1., 1.), (0.1, 0.1))]) + HS.construct([(0.1, 1.5), ((1.0, 1.0), (0.1, 0.1))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = HS.eigenstate(k) DOS = es.DOS(E) - assert DOS.dtype.kind == 'f' + assert DOS.dtype.kind == "f" assert np.allclose(es.norm2(), 1) str(es) def test_pdos1(self, setup): HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((0., 1.), (1., 0.1))]) + HS.construct([(0.1, 1.5), ((0.0, 1.0), (1.0, 0.1))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = HS.eigenstate(k) - DOS = es.DOS(E, 'lorentzian') - PDOS = es.PDOS(E, 'lorentzian') - assert PDOS.dtype.kind == 'f' + DOS = es.DOS(E, "lorentzian") + PDOS = es.PDOS(E, "lorentzian") + assert PDOS.dtype.kind == "f" assert PDOS.shape[0] == 1 assert PDOS.shape[1] == len(HS) assert PDOS.shape[2] == len(E) @@ -1081,13 +1113,13 @@ def test_pdos1(self, setup): def test_pdos2(self, setup): H = setup.H.copy() - H.construct([(0.1, 1.5), (0., 0.1)]) + H.construct([(0.1, 1.5), (0.0, 0.1)]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = H.eigenstate(k) DOS = es.DOS(E) PDOS = es.PDOS(E) - assert PDOS.dtype.kind == 'f' + assert PDOS.dtype.kind == "f" assert np.allclose(PDOS.sum(1), DOS) def test_pdos3(self, setup): @@ -1095,7 +1127,7 @@ def test_pdos3(self, setup): # In this case we will assume an orthogonal # basis, however, the basis is not orthogonal. HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((0., 1.), (1., 0.1))]) + HS.construct([(0.1, 1.5), ((0.0, 1.0), (1.0, 0.1))]) E = np.linspace(-4, 4, 21) es = HS.eigenstate() es.parent = None @@ -1109,19 +1141,19 @@ def test_pdos4(self, setup): # basis. If the basis *is* orthogonal, then # regardless of k, the PDOS will be correct. H = setup.H.copy() - H.construct([(0.1, 1.5), (0., 0.1)]) + H.construct([(0.1, 1.5), (0.0, 0.1)]) E = np.linspace(-4, 4, 21) es = H.eigenstate() es.parent = None DOS = es.DOS(E) PDOS = es.PDOS(E) - assert PDOS.dtype.kind == 'f' + assert PDOS.dtype.kind == "f" assert np.allclose(PDOS.sum(1), DOS) es = H.eigenstate([0.25] * 3) DOS = es.DOS(E) es.parent = None PDOS = es.PDOS(E) - assert PDOS.dtype.kind == 'f' + assert PDOS.dtype.kind == "f" assert np.allclose(PDOS.sum(1), DOS) def test_pdos_nc(self): @@ -1131,6 +1163,7 @@ def test_pdos_nc(self): # this should be Hermitian H[0, 0] = np.array([1, 2, 3, 4]) E = [0] + def dist(E, *args): return np.ones(len(E)) @@ -1166,6 +1199,7 @@ def test_pdos_so(self): # this should be Hermitian H[0, 0] = np.array([1, 2, 3, 4, 0, 0, 3, -4]) E = [0] + def dist(E, *args): return np.ones(len(E)) @@ -1196,38 +1230,38 @@ def dist(E, *args): def test_coop_against_pdos_nonortho(self, setup): HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((0., 1.), (1., 0.1))]) + HS.construct([(0.1, 1.5), ((0.0, 1.0), (1.0, 0.1))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = HS.eigenstate(k) - COOP = es.COOP(E, 'lorentzian') + COOP = es.COOP(E, "lorentzian") - DOS = es.DOS(E, 'lorentzian') + DOS = es.DOS(E, "lorentzian") COOP2DOS = np.array([C.sum() for C in COOP]) assert DOS.shape == COOP2DOS.shape assert np.allclose(DOS, COOP2DOS) # This one returns sparse matrices, so we have to # deal with that. - DOS = es.PDOS(E, 'lorentzian')[0] + DOS = es.PDOS(E, "lorentzian")[0] COOP2DOS = np.array([C.sum(1).A.ravel() for C in COOP]).T assert DOS.shape == COOP2DOS.shape assert np.allclose(DOS, COOP2DOS) def test_coop_against_pdos_ortho(self, setup): H = setup.H.copy() - H.construct([(0.1, 1.5), (0., 1.)]) + H.construct([(0.1, 1.5), (0.0, 1.0)]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = H.eigenstate(k) - COOP = es.COOP(E, 'lorentzian') + COOP = es.COOP(E, "lorentzian") - DOS = es.DOS(E, 'lorentzian') + DOS = es.DOS(E, "lorentzian") COOP2DOS = np.array([C.sum() for C in COOP]) assert DOS.shape == COOP2DOS.shape assert np.allclose(DOS, COOP2DOS) - DOS = es.PDOS(E, 'lorentzian') + DOS = es.PDOS(E, "lorentzian") # matrix.A1 is np.array(matrix).ravel() COOP2DOS = np.array([C.sum(1).A1 for C in COOP]).T assert DOS.shape[1:] == COOP2DOS.shape @@ -1235,60 +1269,72 @@ def test_coop_against_pdos_ortho(self, setup): def test_coop_sp_vs_np(self, setup): HS = setup.HS.copy() - HS.construct([(0.1, 1.5), ((0., 1.), (1., 0.1))]) + HS.construct([(0.1, 1.5), ((0.0, 1.0), (1.0, 0.1))]) E = np.linspace(-4, 4, 21) - for k in ([0] *3, [0.2] * 3): + for k in ([0] * 3, [0.2] * 3): es = HS.eigenstate(k) - COOP_sp = es.COOP(E, 'lorentzian') + COOP_sp = es.COOP(E, "lorentzian") assert issparse(COOP_sp[0]) - es = HS.eigenstate(k, format='array') - COOP_np = es.COOP(E, 'lorentzian') + es = HS.eigenstate(k, format="array") + COOP_np = es.COOP(E, "lorentzian") assert isinstance(COOP_np[0], np.ndarray) for c_sp, c_np in zip(COOP_sp, COOP_np): assert np.allclose(c_sp.toarray(), c_np) def test_spin1(self, setup): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.int32, spin=Spin.POLARIZED) for i in range(10): - j = range(i*2, i*2+3) - H[0, j] = (i, i*2) + j = range(i * 2, i * 2 + 3) + H[0, j] = (i, i * 2) H2 = Hamiltonian(g, 2, dtype=np.int32) for i in range(10): - j = range(i*2, i*2+3) - H2[0, j] = (i, i*2) + j = range(i * 2, i * 2 + 3) + H2[0, j] = (i, i * 2) assert H.spsame(H2) def test_spin2(self, setup): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.int32, spin=Spin.POLARIZED) for i in range(10): - j = range(i*2, i*2+3) - H[0, j] = (i, i*2) + j = range(i * 2, i * 2 + 3) + H[0, j] = (i, i * 2) H2 = Hamiltonian(g, 2, dtype=np.int32) for i in range(10): - j = range(i*2, i*2+3) - H2[0, j] = (i, i*2) + j = range(i * 2, i * 2 + 3) + H2[0, j] = (i, i * 2) assert H.spsame(H2) H2 = Hamiltonian(g, Spin(Spin.POLARIZED), dtype=np.int32) for i in range(10): - j = range(i*2, i*2+3) - H2[0, j] = (i, i*2) + j = range(i * 2, i * 2 + 3) + H2[0, j] = (i, i * 2) assert H.spsame(H2) - H2 = Hamiltonian(g, Spin('polarized'), dtype=np.int32) + H2 = Hamiltonian(g, Spin("polarized"), dtype=np.int32) for i in range(10): - j = range(i*2, i*2+3) - H2[0, j] = (i, i*2) + j = range(i * 2, i * 2 + 3) + H2[0, j] = (i, i * 2) assert H.spsame(H2) def test_transform_up(self): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, spin=Spin.UNPOLARIZED) for i in range(10): H[0, i] = i + 0.1 @@ -1307,10 +1353,14 @@ def test_transform_up(self): assert np.abs(Hcsr[0] - Ht.tocsr(1)).sum() == 0 def test_transform_up_nonortho(self): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, spin=Spin.UNPOLARIZED, orthogonal=False) for i in range(10): - H[0, i] = (i + 0.1, 1.) + H[0, i] = (i + 0.1, 1.0) Hcsr = [H.tocsr(i) for i in range(H.shape[2])] Ht = H.transform(spin=Spin.POLARIZED) @@ -1334,7 +1384,11 @@ def test_transform_up_nonortho(self): assert np.abs(Hcsr[-1] - Ht.tocsr(-1)).sum() == 0 def test_transform_down(self): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, spin=Spin.SPINORBIT) for i in range(10): for j in range(8): @@ -1355,16 +1409,20 @@ def test_transform_down(self): assert np.abs(Hcsr[3] - Ht.tocsr(3)).sum() == 0 def test_transform_down_nonortho(self): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, spin=Spin.SPINORBIT, orthogonal=False) for i in range(10): for j in range(8): H[0, i, j] = i + 0.1 + j - H[0, i, -1] = 1. + H[0, i, -1] = 1.0 Hcsr = [H.tocsr(i) for i in range(H.shape[2])] Ht = H.transform(spin=Spin.UNPOLARIZED) - assert np.abs(0.5 * Hcsr[0]+ 0.5 * Hcsr[1] - Ht.tocsr(0)).sum() == 0 + assert np.abs(0.5 * Hcsr[0] + 0.5 * Hcsr[1] - Ht.tocsr(0)).sum() == 0 assert np.abs(Hcsr[-1] - Ht.tocsr(-1)).sum() == 0 Ht = H.transform(spin=Spin.POLARIZED) @@ -1381,10 +1439,14 @@ def test_transform_down_nonortho(self): @pytest.mark.parametrize("k", [[0, 0, 0], [0.1, 0, 0]]) def test_spin_squared(self, setup, k): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(1, nsc=[3, 1, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(1, nsc=[3, 1, 1]), + ) H = Hamiltonian(g, spin=Spin.POLARIZED) H.construct(([0.1, 1.1], [[0, 0.1], [1, 1.1]])) - H[0, 0] = (0.1, 0.) + H[0, 0] = (0.1, 0.0) H[0, 1] = (0.5, 0.4) es_alpha = H.eigenstate(k, spin=0) es_beta = H.eigenstate(k, spin=1) @@ -1399,60 +1461,72 @@ def test_spin_squared(self, setup, k): assert len(sup) == 2 assert len(sdn) == es_beta.shape[0] - sup, sdn = spin_squared(es_alpha.sub(range(3)).state, es_beta.sub(range(2)).state) + sup, sdn = spin_squared( + es_alpha.sub(range(3)).state, es_beta.sub(range(2)).state + ) assert sup.sum() == pytest.approx(sdn.sum()) assert len(sup) == 3 assert len(sdn) == 2 - sup, sdn = spin_squared(es_alpha.sub(0).state.ravel(), es_beta.sub(range(2)).state) + sup, sdn = spin_squared( + es_alpha.sub(0).state.ravel(), es_beta.sub(range(2)).state + ) assert sup.sum() == pytest.approx(sdn.sum()) assert sup.ndim == 1 assert len(sup) == 1 assert len(sdn) == 2 - sup, sdn = spin_squared(es_alpha.sub(0).state.ravel(), es_beta.sub(0).state.ravel()) + sup, sdn = spin_squared( + es_alpha.sub(0).state.ravel(), es_beta.sub(0).state.ravel() + ) assert sup.sum() == pytest.approx(sdn.sum()) assert sup.ndim == 0 assert sdn.ndim == 0 - sup, sdn = spin_squared(es_alpha.sub(range(2)).state, es_beta.sub(0).state.ravel()) + sup, sdn = spin_squared( + es_alpha.sub(range(2)).state, es_beta.sub(0).state.ravel() + ) assert sup.sum() == pytest.approx(sdn.sum()) assert len(sup) == 2 assert len(sdn) == 1 def test_non_colinear_orthogonal(self, setup): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, spin=Spin.NONCOLINEAR) for i in range(10): - j = range(i*2, i*2+3) + j = range(i * 2, i * 2 + 3) H[i, i, 0] = 0.05 H[i, i, 1] = 0.1 H[i, i, 2] = 0.1 H[i, i, 3] = 0.1 if i > 0: - H[i, i-1, 0] = 1. - H[i, i-1, 1] = 1. + H[i, i - 1, 0] = 1.0 + H[i, i - 1, 1] = 1.0 if i < 9: - H[i, i+1, 0] = 1. - H[i, i+1, 1] = 1. + H[i, i + 1, 0] = 1.0 + H[i, i + 1, 1] = 1.0 eig1 = H.eigh(dtype=np.complex64) assert np.allclose(H.eigh(dtype=np.complex128), eig1) - assert np.allclose(H.eigh(gauge='r', dtype=np.complex128), eig1) + assert np.allclose(H.eigh(gauge="r", dtype=np.complex128), eig1) assert len(eig1) == len(H) - H1 = Hamiltonian(g, dtype=np.float64, spin=Spin('non-collinear')) + H1 = Hamiltonian(g, dtype=np.float64, spin=Spin("non-collinear")) for i in range(10): - j = range(i*2, i*2+3) + j = range(i * 2, i * 2 + 3) H1[i, i, 0] = 0.05 H1[i, i, 1] = 0.1 H1[i, i, 2] = 0.1 H1[i, i, 3] = 0.1 if i > 0: - H1[i, i-1, 0] = 1. - H1[i, i-1, 1] = 1. + H1[i, i - 1, 0] = 1.0 + H1[i, i - 1, 1] = 1.0 if i < 9: - H1[i, i+1, 0] = 1. - H1[i, i+1, 1] = 1. + H1[i, i + 1, 0] = 1.0 + H1[i, i + 1, 1] = 1.0 assert H1.spsame(H) eig1 = H1.eigh(dtype=np.complex64) assert np.allclose(H1.eigh(dtype=np.complex128), eig1) @@ -1488,43 +1562,49 @@ def test_non_colinear_orthogonal(self, setup): assert np.allclose(np.diagonal(vv).T, v) # Ensure we can change gauge for NC stuff - es.change_gauge('R') - es.change_gauge('r') + es.change_gauge("R") + es.change_gauge("r") def test_non_colinear_non_orthogonal(self): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, orthogonal=False, spin=Spin.NONCOLINEAR) for i in range(10): - j = range(i*2, i*2+3) + j = range(i * 2, i * 2 + 3) H[i, i, 0] = 0.1 H[i, i, 1] = 0.05 H[i, i, 2] = 0.1 H[i, i, 3] = 0.1 if i > 0: - H[i, i-1, 0] = 1. - H[i, i-1, 1] = 1. + H[i, i - 1, 0] = 1.0 + H[i, i - 1, 1] = 1.0 if i < 9: - H[i, i+1, 0] = 1. - H[i, i+1, 1] = 1. - H.S[i, i] = 1. + H[i, i + 1, 0] = 1.0 + H[i, i + 1, 1] = 1.0 + H.S[i, i] = 1.0 eig1 = H.eigh(dtype=np.complex64) assert np.allclose(H.eigh(dtype=np.complex128), eig1) assert len(eig1) == len(H) - H1 = Hamiltonian(g, dtype=np.float64, orthogonal=False, spin=Spin('non-collinear')) + H1 = Hamiltonian( + g, dtype=np.float64, orthogonal=False, spin=Spin("non-collinear") + ) for i in range(10): - j = range(i*2, i*2+3) + j = range(i * 2, i * 2 + 3) H1[i, i, 0] = 0.1 H1[i, i, 1] = 0.05 H1[i, i, 2] = 0.1 H1[i, i, 3] = 0.1 if i > 0: - H1[i, i-1, 0] = 1. - H1[i, i-1, 1] = 1. + H1[i, i - 1, 0] = 1.0 + H1[i, i - 1, 1] = 1.0 if i < 9: - H1[i, i+1, 0] = 1. - H1[i, i+1, 1] = 1. - H1.S[i, i] = 1. + H1[i, i + 1, 0] = 1.0 + H1[i, i + 1, 1] = 1.0 + H1.S[i, i] = 1.0 assert H1.spsame(H) eig1 = H1.eigh(dtype=np.complex64) assert np.allclose(H1.eigh(dtype=np.complex128), eig1) @@ -1551,14 +1631,18 @@ def test_non_colinear_non_orthogonal(self): assert np.allclose(np.diagonal(vv).T, v) # Ensure we can change gauge for NC stuff - es.change_gauge('R') - es.change_gauge('r') + es.change_gauge("R") + es.change_gauge("r") def test_spin_orbit_orthogonal(self): - g = Geometry([[i, 0, 0] for i in range(10)], Atom(6, R=1.01), lattice=Lattice(100, nsc=[3, 3, 1])) + g = Geometry( + [[i, 0, 0] for i in range(10)], + Atom(6, R=1.01), + lattice=Lattice(100, nsc=[3, 3, 1]), + ) H = Hamiltonian(g, dtype=np.float64, spin=Spin.SPINORBIT) for i in range(10): - j = range(i*2, i*2+3) + j = range(i * 2, i * 2 + 3) H[i, i, 0] = 0.1 H[i, i, 1] = 0.05 H[i, i, 2] = 0.1 @@ -1568,18 +1652,18 @@ def test_spin_orbit_orthogonal(self): H[i, i, 6] = 0.1 H[i, i, 7] = 0.1 if i > 0: - H[i, i-1, 0] = 1. - H[i, i-1, 1] = 1. + H[i, i - 1, 0] = 1.0 + H[i, i - 1, 1] = 1.0 if i < 9: - H[i, i+1, 0] = 1. - H[i, i+1, 1] = 1. + H[i, i + 1, 0] = 1.0 + H[i, i + 1, 1] = 1.0 eig1 = H.eigh(dtype=np.complex64) assert np.allclose(H.eigh(dtype=np.complex128), eig1) assert len(H.eigh()) == len(H) - H1 = Hamiltonian(g, dtype=np.float64, spin=Spin('spin-orbit')) + H1 = Hamiltonian(g, dtype=np.float64, spin=Spin("spin-orbit")) for i in range(10): - j = range(i*2, i*2+3) + j = range(i * 2, i * 2 + 3) H1[i, i, 0] = 0.1 H1[i, i, 1] = 0.05 H1[i, i, 2] = 0.1 @@ -1589,15 +1673,17 @@ def test_spin_orbit_orthogonal(self): H1[i, i, 6] = 0.1 H1[i, i, 7] = 0.1 if i > 0: - H1[i, i-1, 0] = 1. - H1[i, i-1, 1] = 1. + H1[i, i - 1, 0] = 1.0 + H1[i, i - 1, 1] = 1.0 if i < 9: - H1[i, i+1, 0] = 1. - H1[i, i+1, 1] = 1. + H1[i, i + 1, 0] = 1.0 + H1[i, i + 1, 1] = 1.0 assert H1.spsame(H) eig1 = H1.eigh(dtype=np.complex64) assert np.allclose(H1.eigh(dtype=np.complex128), eig1, atol=1e-5) - assert np.allclose(H.eigh(dtype=np.complex64), H1.eigh(dtype=np.complex128), atol=1e-5) + assert np.allclose( + H.eigh(dtype=np.complex64), H1.eigh(dtype=np.complex128), atol=1e-5 + ) # Create the block matrix for expectation SZ = block_diag(*([H1.spin.Z] * H1.no)) @@ -1627,19 +1713,19 @@ def test_spin_orbit_orthogonal(self): assert np.allclose(np.diagonal(vv).T, v) # Ensure we can change gauge for SO stuff - es.change_gauge('R') - es.change_gauge('r') + es.change_gauge("R") + es.change_gauge("r") def test_finalized(self, setup): assert not setup.H.finalized - setup.H.H[0, 0] = 1. + setup.H.H[0, 0] = 1.0 setup.H.finalize() assert setup.H.finalized assert setup.H.nnz == 1 setup.H.empty() assert not setup.HS.finalized - setup.HS.H[0, 0] = 1. - setup.HS.S[0, 0] = 1. + setup.HS.H[0, 0] = 1.0 + setup.HS.S[0, 0] = 1.0 setup.HS.finalize() assert setup.HS.finalized assert setup.HS.nnz == 1 @@ -1650,7 +1736,7 @@ def test_finalized(self, setup): @pytest.mark.parametrize("ny", [1, 5]) @pytest.mark.parametrize("nz", [1, 6]) def test_tile_same(self, setup, nx, ny, nz): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] # Create reference Hg = Hamiltonian(setup.g.tile(nx, 0).tile(ny, 1).tile(nz, 2)) @@ -1663,15 +1749,17 @@ def test_tile_same(self, setup, nx, ny, nz): H.finalize() Hg.finalize() assert np.allclose(H._csr._D, Hg._csr._D) - assert np.allclose(Hg.Hk([0.1, 0.2, 0.3], format='array'), - H.Hk([0.1, 0.2, 0.3], format='array')) + assert np.allclose( + Hg.Hk([0.1, 0.2, 0.3], format="array"), + H.Hk([0.1, 0.2, 0.3], format="array"), + ) @pytest.mark.slow def test_tile3(self, setup): - R, param = [0.1, 1.1, 2.1, 3.1], [1., 2., 3., 4.] + R, param = [0.1, 1.1, 2.1, 3.1], [1.0, 2.0, 3.0, 4.0] # Create reference - g = Geometry([[0] * 3], Atom('H', R=[4.]), lattice=[1.] * 3) + g = Geometry([[0] * 3], Atom("H", R=[4.0]), lattice=[1.0] * 3) g.set_nsc([7] * 3) # Now create bigger geometry @@ -1695,15 +1783,15 @@ def func(self, ia, atoms, atoms_xyz=None): io = self.geometry.a2o(ia) # Set on-site on first and second orbital odx = self.geometry.a2o(idx[0]) - self[io, odx] = -1. - self[io+1, odx+1] = 1. + self[io, odx] = -1.0 + self[io + 1, odx + 1] = 1.0 # Set connecting odx = self.geometry.a2o(idx[1]) self[io, odx] = 0.2 - self[io, odx+1] = 0.01 - self[io+1, odx] = 0.01 - self[io+1, odx+1] = 0.3 + self[io, odx + 1] = 0.01 + self[io + 1, odx] = 0.01 + self[io + 1, odx + 1] = 0.3 setup.H2.construct(func) Hbig = setup.H2.tile(3, 0).tile(3, 1) @@ -1719,7 +1807,7 @@ def func(self, ia, atoms, atoms_xyz=None): @pytest.mark.slow def test_repeat1(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] # Create reference Hg = Hamiltonian(setup.g.repeat(2, 0)) @@ -1735,7 +1823,7 @@ def test_repeat1(self, setup): @pytest.mark.slow def test_repeat2(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] # Create reference Hg = Hamiltonian(setup.g.repeat(2, 0).repeat(2, 1).repeat(2, 2)) @@ -1743,7 +1831,7 @@ def test_repeat2(self, setup): Hg.finalize() H = Hamiltonian(setup.g) H.construct([R, param]) - H = H.repeat(2, 0).repeat(2, 1). repeat(2, 2) + H = H.repeat(2, 0).repeat(2, 1).repeat(2, 2) assert Hg.spsame(H) H.finalize() Hg.finalize() @@ -1751,10 +1839,10 @@ def test_repeat2(self, setup): @pytest.mark.slow def test_repeat3(self, setup): - R, param = [0.1, 1.1, 2.1, 3.1], [1., 2., 3., 4.] + R, param = [0.1, 1.1, 2.1, 3.1], [1.0, 2.0, 3.0, 4.0] # Create reference - g = Geometry([[0] * 3], Atom('H', R=[4.]), lattice=[1.] * 3) + g = Geometry([[0] * 3], Atom("H", R=[4.0]), lattice=[1.0] * 3) g.set_nsc([7] * 3) # Now create bigger geometry @@ -1779,15 +1867,15 @@ def func(self, ia, atoms, atoms_xyz=None): io = self.geometry.a2o(ia) # Set on-site on first and second orbital odx = self.geometry.a2o(idx[0]) - self[io, odx] = -1. - self[io+1, odx+1] = 1. + self[io, odx] = -1.0 + self[io + 1, odx + 1] = 1.0 # Set connecting odx = self.geometry.a2o(idx[1]) self[io, odx] = 0.2 - self[io, odx+1] = 0.01 - self[io+1, odx] = 0.01 - self[io+1, odx+1] = 0.3 + self[io, odx + 1] = 0.01 + self[io + 1, odx] = 0.01 + self[io + 1, odx + 1] = 0.3 setup.H2.construct(func) Hbig = setup.H2.repeat(3, 0).repeat(3, 1) @@ -1803,7 +1891,7 @@ def func(self, ia, atoms, atoms_xyz=None): setup.H2.empty() def test_sub1(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] # Create reference H = Hamiltonian(setup.g) @@ -1821,7 +1909,7 @@ def test_sub1(self, setup): assert len(Hg) == len(setup.g) def test_set_nsc1(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] # Create reference H = Hamiltonian(setup.g.copy()) @@ -1842,7 +1930,7 @@ def test_set_nsc1(self, setup): assert Hg.spsame(H) def test_shift1(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] H = Hamiltonian(setup.g.copy()) H.construct([R, param]) eig0 = H.eigh()[0] @@ -1850,7 +1938,7 @@ def test_shift1(self, setup): assert H.eigh()[0] == pytest.approx(eig0 + 0.2) def test_shift2(self, setup): - R, param = [0.1, 1.5], [(1., 1.), (0.1, 0.1)] + R, param = [0.1, 1.5], [(1.0, 1.0), (0.1, 0.1)] H = Hamiltonian(setup.g.copy(), orthogonal=False) H.construct([R, param]) eig0 = H.eigh()[0] @@ -1858,8 +1946,8 @@ def test_shift2(self, setup): assert H.eigh()[0] == pytest.approx(eig0 + 0.2) def test_shift3(self, setup): - R, param = [0.1, 1.5], [(1., -1., 1.), (0.1, 0.1, 0.1)] - H = Hamiltonian(setup.g.copy(), spin=Spin('P'), orthogonal=False) + R, param = [0.1, 1.5], [(1.0, -1.0, 1.0), (0.1, 0.1, 0.1)] + H = Hamiltonian(setup.g.copy(), spin=Spin("P"), orthogonal=False) H.construct([R, param]) eig0_0 = H.eigh(spin=0)[0] eig1_0 = H.eigh(spin=1)[0] @@ -1871,36 +1959,36 @@ def test_shift3(self, setup): assert H.eigh(spin=1)[0] == pytest.approx(eig1_0) def test_fermi_level(self, setup): - R, param = [0.1, 1.5], [(1., 1.), (2.1, 0.1)] + R, param = [0.1, 1.5], [(1.0, 1.0), (2.1, 0.1)] H = Hamiltonian(setup.g.copy(), orthogonal=False) H.construct([R, param]) bz = MonkhorstPack(H, [10, 10, 1]) q = 0.9 Ef = H.fermi_level(bz, q=q) H.shift(-Ef) - assert H.fermi_level(bz, q=q) == pytest.approx(0., abs=1e-6) + assert H.fermi_level(bz, q=q) == pytest.approx(0.0, abs=1e-6) def test_fermi_level_spin(self, setup): - R, param = [0.1, 1.5], [(1., 1.), (2.1, 0.1)] - H = Hamiltonian(setup.g.copy(), spin=Spin('P')) + R, param = [0.1, 1.5], [(1.0, 1.0), (2.1, 0.1)] + H = Hamiltonian(setup.g.copy(), spin=Spin("P")) H.construct([R, param]) bz = MonkhorstPack(H, [10, 10, 1]) q = 1.1 Ef = H.fermi_level(bz, q=q) assert np.asarray(Ef).ndim == 0 H.shift(-Ef) - assert H.fermi_level(bz, q=q) == pytest.approx(0., abs=1e-6) + assert H.fermi_level(bz, q=q) == pytest.approx(0.0, abs=1e-6) def test_fermi_level_spin_separate(self, setup): - R, param = [0.1, 1.5], [(1., 1.), (2.1, 0.1)] - H = Hamiltonian(setup.g.copy(), spin=Spin('P')) + R, param = [0.1, 1.5], [(1.0, 1.0), (2.1, 0.1)] + H = Hamiltonian(setup.g.copy(), spin=Spin("P")) H.construct([R, param]) bz = MonkhorstPack(H, [10, 10, 1]) q = [0.5, 0.3] Ef = H.fermi_level(bz, q=q) assert len(Ef) == 2 H.shift(-Ef) - assert np.allclose(H.fermi_level(bz, q=q), 0.) + assert np.allclose(H.fermi_level(bz, q=q), 0.0) def test_wrap_oplist(self, setup): R, param = [0.1, 1.5], [1, 2.1] @@ -1908,25 +1996,31 @@ def test_wrap_oplist(self, setup): H.construct([R, param]) bz = MonkhorstPack(H, [10, 10, 1]) E = np.linspace(-4, 4, 21) - dist = get_distribution('gaussian', smearing=0.05) + dist = get_distribution("gaussian", smearing=0.05) + def wrap(es, parent, k, weight): DOS = es.DOS(E, distribution=dist) PDOS = es.PDOS(E, distribution=dist) vel = es.velocity() * es.occupation() return oplist([DOS, PDOS, vel]) + bz_avg = bz.apply.average results = bz_avg.eigenstate(wrap=wrap) - assert np.allclose(bz_avg.eigenstate(wrap=lambda es: es.DOS(E, distribution=dist)), results[0]) - assert np.allclose(bz_avg.eigenstate(wrap=lambda es: es.PDOS(E, distribution=dist)), results[1]) + assert np.allclose( + bz_avg.eigenstate(wrap=lambda es: es.DOS(E, distribution=dist)), results[0] + ) + assert np.allclose( + bz_avg.eigenstate(wrap=lambda es: es.PDOS(E, distribution=dist)), results[1] + ) def test_edges1(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] H = Hamiltonian(setup.g) H.construct([R, param]) assert len(H.edges(0)) == 4 def test_edges2(self, setup): - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] H = Hamiltonian(setup.g) H.construct([R, param]) with pytest.raises(ValueError): @@ -1938,15 +2032,15 @@ def func(self, ia, atoms, atoms_xyz=None): io = self.geometry.a2o(ia) # Set on-site on first and second orbital odx = self.geometry.a2o(idx[0]) - self[io, odx] = -1. - self[io+1, odx+1] = 1. + self[io, odx] = -1.0 + self[io + 1, odx + 1] = 1.0 # Set connecting odx = self.geometry.a2o(idx[1]) self[io, odx] = 0.2 - self[io, odx+1] = 0.01 - self[io+1, odx] = 0.01 - self[io+1, odx+1] = 0.3 + self[io, odx + 1] = 0.01 + self[io + 1, odx] = 0.01 + self[io + 1, odx + 1] = 0.3 H2 = setup.H2.copy() H2.construct(func) @@ -1975,12 +2069,12 @@ def test_wavefunction1(): o1 = SphericalOrbital(0, (np.linspace(0, 2, N), np.exp(-np.linspace(0, 100, N)))) G = Geometry([[1] * 3, [2] * 3], Atom(6, o1), lattice=[4, 4, 4]) H = Hamiltonian(G) - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] H.construct([R, param]) ES = H.eigenstate(dtype=np.float64) # Plot in the full thing grid = Grid(0.1, geometry=H.geometry) - grid.fill(0.) + grid.fill(0.0) ES.sub(0).wavefunction(grid) @@ -1989,13 +2083,13 @@ def test_wavefunction2(): o1 = SphericalOrbital(0, (np.linspace(0, 2, N), np.exp(-np.linspace(0, 100, N)))) G = Geometry([[1] * 3, [2] * 3], Atom(6, o1), lattice=[4, 4, 4]) H = Hamiltonian(G) - R, param = [0.1, 1.5], [1., 0.1] + R, param = [0.1, 1.5], [1.0, 0.1] H.construct([R, param]) ES = H.eigenstate(dtype=np.float64) # This is effectively plotting outside where no atoms exists # (there could however still be psi weight). grid = Grid(0.1, lattice=Lattice([2, 2, 2], origin=[2] * 3)) - grid.fill(0.) + grid.fill(0.0) ES.sub(0).wavefunction(grid) @@ -2003,14 +2097,13 @@ def test_wavefunction3(): N = 50 o1 = SphericalOrbital(0, (np.linspace(0, 2, N), np.exp(-np.linspace(0, 100, N)))) G = Geometry([[1] * 3, [2] * 3], Atom(6, o1), lattice=[4, 4, 4]) - H = Hamiltonian(G, spin=Spin('nc')) - R, param = [0.1, 1.5], [[0., 0., 0.1, -0.1], - [1., 1., 0.1, -0.1]] + H = Hamiltonian(G, spin=Spin("nc")) + R, param = [0.1, 1.5], [[0.0, 0.0, 0.1, -0.1], [1.0, 1.0, 0.1, -0.1]] H.construct([R, param]) ES = H.eigenstate() # Plot in the full thing grid = Grid(0.1, dtype=np.complex128, lattice=Lattice([2, 2, 2], origin=[-1] * 3)) - grid.fill(0.) + grid.fill(0.0) ES.sub(0).wavefunction(grid) @@ -2018,12 +2111,11 @@ def test_wavefunction_eta(): N = 50 o1 = SphericalOrbital(0, (np.linspace(0, 2, N), np.exp(-np.linspace(0, 100, N)))) G = Geometry([[1] * 3, [2] * 3], Atom(6, o1), lattice=[4, 4, 4]) - H = Hamiltonian(G, spin=Spin('nc')) - R, param = [0.1, 1.5], [[0., 0., 0.1, -0.1], - [1., 1., 0.1, -0.1]] + H = Hamiltonian(G, spin=Spin("nc")) + R, param = [0.1, 1.5], [[0.0, 0.0, 0.1, -0.1], [1.0, 1.0, 0.1, -0.1]] H.construct([R, param]) ES = H.eigenstate() # Plot in the full thing grid = Grid(0.1, dtype=np.complex128, lattice=Lattice([2, 2, 2], origin=[-1] * 3)) - grid.fill(0.) + grid.fill(0.0) ES.sub(0).wavefunction(grid, eta=True) diff --git a/src/sisl/physics/tests/test_overlap.py b/src/sisl/physics/tests/test_overlap.py index 108972518f..7855553d8b 100644 --- a/src/sisl/physics/tests/test_overlap.py +++ b/src/sisl/physics/tests/test_overlap.py @@ -12,29 +12,34 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) n = 60 rf = np.linspace(0, bond * 1.01, n) rf = (rf, rf) - orb = SphericalOrbital(1, rf, 2.) + orb = SphericalOrbital(1, rf, 2.0) C = Atom(6, orb.toAtomicOrbital()) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.S = Overlap(self.g) return t() class TestOverlap: - def test_objects(self, setup): assert len(setup.S.xyz) == 2 assert setup.g.no == len(setup.S) @@ -47,9 +52,9 @@ def test_ortho(self, setup): def test_set1(self, setup): S = setup.S.copy() - S.S[0, 0] = 1. - assert S[0, 0] == 1. - assert S[1, 0] == 0. + S.S[0, 0] = 1.0 + assert S[0, 0] == 1.0 + assert S[1, 0] == 0.0 def test_fromsp(self, setup): S = setup.S.copy() diff --git a/src/sisl/physics/tests/test_physics_sparse.py b/src/sisl/physics/tests/test_physics_sparse.py index 7433b628c0..59b5e22ec2 100644 --- a/src/sisl/physics/tests/test_physics_sparse.py +++ b/src/sisl/physics/tests/test_physics_sparse.py @@ -31,10 +31,10 @@ def test_S(): sp = SparseOrbitalBZ(gr, orthogonal=False) sp[0, 0] = 0.5 sp[1, 1] = 0.5 - sp.S[0, 0] = 1. - sp.S[1, 1] = 1. + sp.S[0, 0] = 1.0 + sp.S[1, 1] = 1.0 assert sp[1, 1, 0] == pytest.approx(0.5) - assert sp.S[1, 1] == pytest.approx(1.) + assert sp.S[1, 1] == pytest.approx(1.0) def test_eigh_orthogonal(): @@ -53,8 +53,8 @@ def test_eigh_non_orthogonal(): sp = SparseOrbitalBZ(gr, orthogonal=False) sp[0, 0] = 0.5 sp[1, 1] = 0.5 - sp.S[0, 0] = 1. - sp.S[1, 1] = 1. + sp.S[0, 0] = 1.0 + sp.S[1, 1] = 1.0 assert np.allclose(sp.eigh(), [0.5, 0.5]) @@ -71,18 +71,19 @@ def test_eigsh_orthogonal(): def test_eigsh_non_orthogonal(): sp = SparseOrbitalBZ(_get(), orthogonal=False) - sp.construct([(0.1, 1.44), ([0, 1.], [-2.7, 0])]) + sp.construct([(0.1, 1.44), ([0, 1.0], [-2.7, 0])]) sp.eigsh(n=1) def test_pickle_non_orthogonal(): import pickle as p + gr = _get() sp = SparseOrbitalBZ(gr, orthogonal=False) sp[0, 0] = 0.5 sp[1, 1] = 0.5 - sp.S[0, 0] = 1. - sp.S[1, 1] = 1. + sp.S[0, 0] = 1.0 + sp.S[1, 1] = 1.0 s = p.dumps(sp) SP = p.loads(s) assert sp.spsame(SP) @@ -91,12 +92,13 @@ def test_pickle_non_orthogonal(): def test_pickle_non_orthogonal_spin(): import pickle as p + gr = _get() - sp = SparseOrbitalBZSpin(gr, spin=Spin('p'), orthogonal=False) + sp = SparseOrbitalBZSpin(gr, spin=Spin("p"), orthogonal=False) sp[0, 0, :] = 0.5 sp[1, 1, :] = 0.5 - sp.S[0, 0] = 1. - sp.S[1, 1] = 1. + sp.S[0, 0] = 1.0 + sp.S[1, 1] = 1.0 s = p.dumps(sp) SP = p.loads(s) assert sp.spsame(SP) @@ -107,7 +109,7 @@ def test_pickle_non_orthogonal_spin(): @pytest.mark.parametrize("n1", [1, 3]) @pytest.mark.parametrize("n2", [1, 4]) def test_sparse_orbital_bz_hermitian(n0, n1, n2): - g = geom.fcc(1., Atom(1, R=1.5)) * 2 + g = geom.fcc(1.0, Atom(1, R=1.5)) * 2 s = SparseOrbitalBZ(g) s.construct([[0.1, 1.51], [1, 2]]) s = s.tile(n0, 0).tile(n1, 1).tile(n2, 2) @@ -118,7 +120,7 @@ def test_sparse_orbital_bz_hermitian(n0, n1, n2): # orbitals connecting to io edges = s.edges(io) # Figure out the transposed supercell indices of the edges - isc = - s.geometry.o2isc(edges) + isc = -s.geometry.o2isc(edges) # Convert to supercell IO = s.geometry.lattice.sc_index(isc) * no + io # Figure out if 'io' is also in the back-edges @@ -130,31 +132,29 @@ def test_sparse_orbital_bz_hermitian(n0, n1, n2): assert s.nnz == nnz # Since we are also dealing with f32 data-types we cannot go beyond 1e-7 - approx_zero = pytest.approx(0., abs=1e-5) + approx_zero = pytest.approx(0.0, abs=1e-5) for k0 in [0, 0.1]: for k1 in [0, -0.15]: for k2 in [0, 0.33333]: k = (k0, k1, k2) - if np.allclose(k, 0.): + if np.allclose(k, 0.0): dtypes = [None, np.float32, np.float64] else: dtypes = [None, np.complex64, np.complex128] # Also assert Pk == Pk.H for all data-types for dtype in dtypes: - Pk = s.Pk(k=k, format='csr', dtype=dtype) + Pk = s.Pk(k=k, format="csr", dtype=dtype) assert abs(Pk - Pk.getH()).toarray().max() == approx_zero - Pk = s.Pk(k=k, format='array', dtype=dtype) + Pk = s.Pk(k=k, format="array", dtype=dtype) assert np.abs(Pk - np.conj(Pk.T)).max() == approx_zero def test_sparse_orbital_bz_non_colinear(): - M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin('NC')) - M.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4], - [0.2, 0.3, 0.4, 0.5]])) + M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin("NC")) + M.construct(([0.1, 1.44], [[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5]])) M.finalize() MT = M.transpose() @@ -170,11 +170,9 @@ def test_sparse_orbital_bz_non_colinear(): def test_sparse_orbital_bz_non_colinear_trs_kramers_theorem(): - M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin('NC')) + M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin("NC")) - M.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4], - [0.2, 0.3, 0.4, 0.5]])) + M.construct(([0.1, 1.44], [[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5]])) M.finalize() M = (M + M.transpose(True)) * 0.5 @@ -189,20 +187,32 @@ def test_sparse_orbital_bz_non_colinear_trs_kramers_theorem(): def test_sparse_orbital_bz_spin_orbit_warns_hermitian(): - M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin('SO')) + M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin("SO")) - with pytest.warns(SislWarning, match='Hermitian'): - M.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) + with pytest.warns(SislWarning, match="Hermitian"): + M.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) def test_sparse_orbital_bz_spin_orbit(): - M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin('SO')) - - M.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.0, 0.0, 0.3, -0.4], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) + M = SparseOrbitalBZSpin(geom.graphene(), spin=Spin("SO")) + + M.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.0, 0.0, 0.3, -0.4], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) M.finalize() MT = M.transpose() @@ -215,11 +225,17 @@ def test_sparse_orbital_bz_spin_orbit(): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_sparse_orbital_bz_spin_orbit_trs_kramers_theorem(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='SO') - - M.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) + M = SparseOrbitalBZSpin(geom.graphene(), spin="SO") + + M.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) M.finalize() M = (M + M.transpose(True)) / 2 @@ -236,37 +252,43 @@ def test_sparse_orbital_bz_spin_orbit_trs_kramers_theorem(): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") @pytest.mark.xfail(reason="Construct does not impose hermitian property") def test_sparse_orbital_bz_spin_orbit_hermitian_not(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='SO') - - M.construct(([0.1, 1.44], - [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])) + M = SparseOrbitalBZSpin(geom.graphene(), spin="SO") + + M.construct( + ( + [0.1, 1.44], + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + ], + ) + ) M.finalize() new = (M + M.transpose(True)) / 2 assert np.abs((M - new)._csr._D).sum() == 0 def test_sparse_orbital_transform_ortho_unpolarized(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='unpolarized') + M = SparseOrbitalBZSpin(geom.graphene(), spin="unpolarized") a = np.arange(M.spin.size) + 0.3 M.construct(([0.1, 1.44], [a, a + 0.1])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='unpolarized') + Mt = M.transform(spin="unpolarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 - Mt = M.transform(spin='polarized') + Mt = M.transform(spin="polarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[0] - Mt.tocsr(1)).sum() == 0 - Mt = M.transform(spin='non-colinear') + Mt = M.transform(spin="non-colinear") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[0] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 assert np.abs(Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='so') + Mt = M.transform(spin="so") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[0] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 @@ -274,28 +296,28 @@ def test_sparse_orbital_transform_ortho_unpolarized(): def test_sparse_orbital_transform_nonortho_unpolarized(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='unpolarized', orthogonal=False) + M = SparseOrbitalBZSpin(geom.graphene(), spin="unpolarized", orthogonal=False) a = np.arange(M.spin.size + 1) + 0.3 M.construct(([0.1, 1.44], [a, a + 0.1])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='unpolarized') + Mt = M.transform(spin="unpolarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='polarized') + Mt = M.transform(spin="polarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[0] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='non-colinear') + Mt = M.transform(spin="non-colinear") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[0] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='so') + Mt = M.transform(spin="so") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[0] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 @@ -303,26 +325,26 @@ def test_sparse_orbital_transform_nonortho_unpolarized(): def test_sparse_orbital_transform_ortho_polarized(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='polarized') + M = SparseOrbitalBZSpin(geom.graphene(), spin="polarized") a = np.arange(M.spin.size) + 0.3 M.construct(([0.1, 1.44], [a, a + 0.1])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='unpolarized') + Mt = M.transform(spin="unpolarized") assert np.abs(0.5 * Mcsr[0] + 0.5 * Mcsr[1] - Mt.tocsr(0)).sum() == 0 - Mt = M.transform(spin='polarized') + Mt = M.transform(spin="polarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 - Mt = M.transform(spin='non-colinear') + Mt = M.transform(spin="non-colinear") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 assert np.abs(Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='so') + Mt = M.transform(spin="so") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 @@ -330,26 +352,26 @@ def test_sparse_orbital_transform_ortho_polarized(): def test_sparse_orbital_transform_ortho_nc(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='non-colinear') + M = SparseOrbitalBZSpin(geom.graphene(), spin="non-colinear") a = np.arange(M.spin.size) + 0.3 M.construct(([0.1, 1.44], [a, a + 0.1])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='unpolarized') + Mt = M.transform(spin="unpolarized") assert np.abs(0.5 * Mcsr[0] + 0.5 * Mcsr[1] - Mt.tocsr(0)).sum() == 0 - Mt = M.transform(spin='polarized') + Mt = M.transform(spin="polarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 - Mt = M.transform(spin='non-colinear') + Mt = M.transform(spin="non-colinear") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[2] - Mt.tocsr(2)).sum() == 0 assert np.abs(Mcsr[3] - Mt.tocsr(3)).sum() == 0 - Mt = M.transform(spin='so') + Mt = M.transform(spin="so") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[2] - Mt.tocsr(2)).sum() == 0 @@ -358,26 +380,26 @@ def test_sparse_orbital_transform_ortho_nc(): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_sparse_orbital_transform_ortho_so(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='so') + M = SparseOrbitalBZSpin(geom.graphene(), spin="so") a = np.arange(M.spin.size) + 0.3 M.construct(([0.1, 1.44], [a, a + 0.1])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='unpolarized') + Mt = M.transform(spin="unpolarized") assert np.abs(0.5 * Mcsr[0] + 0.5 * Mcsr[1] - Mt.tocsr(0)).sum() == 0 - Mt = M.transform(spin='polarized') + Mt = M.transform(spin="polarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 - Mt = M.transform(spin='non-colinear') + Mt = M.transform(spin="non-colinear") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[2] - Mt.tocsr(2)).sum() == 0 assert np.abs(Mcsr[3] - Mt.tocsr(3)).sum() == 0 - Mt = M.transform(spin='so') + Mt = M.transform(spin="so") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[2] - Mt.tocsr(2)).sum() == 0 @@ -386,29 +408,29 @@ def test_sparse_orbital_transform_ortho_so(): @pytest.mark.filterwarnings("ignore", message="*is NOT Hermitian for on-site") def test_sparse_orbital_transform_nonortho_so(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='so', orthogonal=False) + M = SparseOrbitalBZSpin(geom.graphene(), spin="so", orthogonal=False) a = np.arange(M.spin.size + 1) + 0.3 M.construct(([0.1, 1.44], [a, a + 0.1])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='unpolarized') + Mt = M.transform(spin="unpolarized") assert np.abs(0.5 * Mcsr[0] + 0.5 * Mcsr[1] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='polarized') + Mt = M.transform(spin="polarized") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='non-colinear') + Mt = M.transform(spin="non-colinear") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[2] - Mt.tocsr(2)).sum() == 0 assert np.abs(Mcsr[3] - Mt.tocsr(3)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='so') + Mt = M.transform(spin="so") assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[2] - Mt.tocsr(2)).sum() == 0 @@ -417,8 +439,8 @@ def test_sparse_orbital_transform_nonortho_so(): def test_sparse_orbital_transform_basis(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='polarized', orthogonal=False) - M.construct(([0.1, 1.44], [(3., 2., 1.), (0.3, 0.2, 0.)])) + M = SparseOrbitalBZSpin(geom.graphene(), spin="polarized", orthogonal=False) + M.construct(([0.1, 1.44], [(3.0, 2.0, 1.0), (0.3, 0.2, 0.0)])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] @@ -429,46 +451,60 @@ def test_sparse_orbital_transform_basis(): assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 -@pytest.mark.xfail(sys.platform.startswith("win"), reason="Data type cannot be float128") +@pytest.mark.xfail( + sys.platform.startswith("win"), reason="Data type cannot be float128" +) def test_sparse_orbital_transform_combinations(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='polarized', orthogonal=False, dtype=np.int32) + M = SparseOrbitalBZSpin( + geom.graphene(), spin="polarized", orthogonal=False, dtype=np.int32 + ) M.construct(([0.1, 1.44], [(3, 2, 1), (2, 1, 0)])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='non-colinear', dtype=np.float64, orthogonal=True).transform(spin='polarized', orthogonal=False) + Mt = M.transform(spin="non-colinear", dtype=np.float64, orthogonal=True).transform( + spin="polarized", orthogonal=False + ) assert M.dim == Mt.dim assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(dtype=np.float128, orthogonal=True).transform(spin='so', dtype=np.float64, orthogonal=False) + Mt = M.transform(dtype=np.float128, orthogonal=True).transform( + spin="so", dtype=np.float64, orthogonal=False + ) assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='polarized', orthogonal=True).transform(spin='so', dtype=np.float64, orthogonal=False) + Mt = M.transform(spin="polarized", orthogonal=True).transform( + spin="so", dtype=np.float64, orthogonal=False + ) assert np.abs(Mcsr[0] - Mt.tocsr(0)).sum() == 0 assert np.abs(Mcsr[1] - Mt.tocsr(1)).sum() == 0 assert np.abs(Mt.tocsr(2)).sum() == 0 assert np.abs(Mcsr[-1] - Mt.tocsr(-1)).sum() == 0 - Mt = M.transform(spin='unpolarized', dtype=np.float32, orthogonal=True).transform(dtype=np.complex128, orthogonal=False) + Mt = M.transform(spin="unpolarized", dtype=np.float32, orthogonal=True).transform( + dtype=np.complex128, orthogonal=False + ) assert np.abs(0.5 * Mcsr[0] + 0.5 * Mcsr[1] - Mt.tocsr(0)).sum() == 0 def test_sparse_orbital_transform_matrix(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='polarized', orthogonal=False, dtype=np.int32) + M = SparseOrbitalBZSpin( + geom.graphene(), spin="polarized", orthogonal=False, dtype=np.int32 + ) M.construct(([0.1, 1.44], [(1, 2, 3), (4, 5, 6)])) M.finalize() Mcsr = [M.tocsr(i) for i in range(M.shape[2])] - Mt = M.transform(spin='unpolarized', matrix=np.ones((1, 3)), orthogonal=True) + Mt = M.transform(spin="unpolarized", matrix=np.ones((1, 3)), orthogonal=True) assert Mt.dim == 1 assert np.abs(Mcsr[0] + Mcsr[1] + Mcsr[2] - Mt.tocsr(0)).sum() == 0 - Mt = M.transform(spin='polarized', matrix=np.ones((2, 3)), orthogonal=True) + Mt = M.transform(spin="polarized", matrix=np.ones((2, 3)), orthogonal=True) assert Mt.dim == 2 assert np.abs(Mcsr[0] + Mcsr[1] + Mcsr[2] - Mt.tocsr(1)).sum() == 0 @@ -476,30 +512,36 @@ def test_sparse_orbital_transform_matrix(): assert Mt.dim == 3 assert np.abs(Mcsr[0] + Mcsr[1] + Mcsr[2] - Mt.tocsr(2)).sum() == 0 - Mt = M.transform(spin='non-colinear', matrix=np.ones((4, 3)), orthogonal=True, dtype=np.float64) + Mt = M.transform( + spin="non-colinear", matrix=np.ones((4, 3)), orthogonal=True, dtype=np.float64 + ) assert Mt.dim == 4 assert np.abs(Mcsr[0] + Mcsr[1] + Mcsr[2] - Mt.tocsr(3)).sum() == 0 - Mt = M.transform(spin='non-colinear', matrix=np.ones((5, 3)), dtype=np.float64) + Mt = M.transform(spin="non-colinear", matrix=np.ones((5, 3)), dtype=np.float64) assert Mt.dim == 5 assert np.abs(Mcsr[0] + Mcsr[1] + Mcsr[2] - Mt.tocsr(4)).sum() == 0 - Mt = M.transform(spin='so', matrix=np.ones((8, 3)), orthogonal=True, dtype=np.float64) + Mt = M.transform( + spin="so", matrix=np.ones((8, 3)), orthogonal=True, dtype=np.float64 + ) assert Mt.dim == 8 assert np.abs(Mcsr[0] + Mcsr[1] + Mcsr[2] - Mt.tocsr(7)).sum() == 0 - Mt = M.transform(spin='so', matrix=np.ones((9, 3)), dtype=np.float64) + Mt = M.transform(spin="so", matrix=np.ones((9, 3)), dtype=np.float64) assert Mt.dim == 9 assert np.abs(Mcsr[0] + Mcsr[1] + Mcsr[2] - Mt.tocsr(8)).sum() == 0 def test_sparse_orbital_transform_fail(): - M = SparseOrbitalBZSpin(geom.graphene(), spin='polarized', orthogonal=False, dtype=np.int32) + M = SparseOrbitalBZSpin( + geom.graphene(), spin="polarized", orthogonal=False, dtype=np.int32 + ) M.construct(([0.1, 1.44], [(1, 2, 3), (4, 5, 6)])) M.finalize() with pytest.raises(ValueError): - M.transform(np.zeros([2, 2]), spin='unpolarized') + M.transform(np.zeros([2, 2]), spin="unpolarized") @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128]) diff --git a/src/sisl/physics/tests/test_self_energy.py b/src/sisl/physics/tests/test_self_energy.py index ac4b5c8611..21ff805686 100644 --- a/src/sisl/physics/tests/test_self_energy.py +++ b/src/sisl/physics/tests/test_self_energy.py @@ -24,38 +24,47 @@ WideBandSE, ) -pytestmark = [pytest.mark.physics, pytest.mark.self_energy, - pytest.mark.filterwarnings("ignore", category=SparseEfficiencyWarning)] +pytestmark = [ + pytest.mark.physics, + pytest.mark.self_energy, + pytest.mark.filterwarnings("ignore", category=SparseEfficiencyWarning), +] @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + C = Atom(Z=6, R=[bond * 1.01]) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.H = Hamiltonian(self.g) - func = self.H.create_construct([0.1, bond+0.1], [0., -2.7]) + func = self.H.create_construct([0.1, bond + 0.1], [0.0, -2.7]) self.H.construct(func) self.HS = Hamiltonian(self.g, orthogonal=False) - func = self.HS.create_construct([0.1, bond+0.1], [(0., 1.), (-2.7, 0.)]) + func = self.HS.create_construct( + [0.1, bond + 0.1], [(0.0, 1.0), (-2.7, 0.0)] + ) self.HS.construct(func) + return t() def test_objects(setup): - for D, si, sid in [('+A', 0, 1), - ('-A', 0, -1), - ('+B', 1, 1), - ('-B', 1, -1)]: + for D, si, sid in [("+A", 0, 1), ("-A", 0, -1), ("+B", 1, 1), ("-B", 1, -1)]: SE = SemiInfinite(setup.H.copy(), D) assert SE.semi_inf == si assert SE.semi_inf_dir == sid @@ -63,7 +72,7 @@ def test_objects(setup): def test_sancho_orthogonal(setup): - SE = RecursiveSI(setup.H.copy(), '+A') + SE = RecursiveSI(setup.H.copy(), "+A") assert not np.allclose(SE.self_energy(0.1), SE.self_energy(0.1, bulk=True)) @@ -74,22 +83,22 @@ def test_sancho_bloch_zero_off_diag(setup): # disconnect transverse directions H.set_nsc(b=1) no = len(H) - SE = RecursiveSI(H, '+A') + SE = RecursiveSI(H, "+A") for nb in [2, 4, 5]: bloch = Bloch(1, nb, 1) for E in [-0.1, 0.2, 0.4, -0.5]: se = bloch(SE.self_energy, [0.2, 0.2, 0.2], E=E) for b in range(1, nb): - off = b*no - assert np.allclose(se[:no, :no], se[off:off+no, off:off+no]) - se[off:off+no, off:off+no] = 0. - se[:no, :no] = 0. - assert np.allclose(se, 0.) + off = b * no + assert np.allclose(se[:no, :no], se[off : off + no, off : off + no]) + se[off : off + no, off : off + no] = 0.0 + se[:no, :no] = 0.0 + assert np.allclose(se, 0.0) def test_sancho_orthogonal_dtype(setup): - SE = RecursiveSI(setup.H, '+A') + SE = RecursiveSI(setup.H, "+A") s64 = SE.self_energy(0.1, dtype=np.complex64) s128 = SE.self_energy(0.1) assert s64.dtype == np.complex64 @@ -100,28 +109,27 @@ def test_sancho_orthogonal_dtype(setup): def test_sancho_warning(): lattice = Lattice([1, 1, 10], nsc=[5, 5, 1]) C = Atom(Z=6, R=[2 * 1.01]) - g = Geometry([[0., 0., 0.]], - atoms=C, lattice=lattice) + g = Geometry([[0.0, 0.0, 0.0]], atoms=C, lattice=lattice) H = Hamiltonian(g) - func = H.create_construct([0.1, 1.01, 2.01], [0., -2., -1.]) + func = H.create_construct([0.1, 1.01, 2.01], [0.0, -2.0, -1.0]) H.construct(func) with pytest.warns(sisl.SislWarning, match=r"first neighbouring cell.*\[1.\]"): - RecursiveSI(H, '+A') + RecursiveSI(H, "+A") def test_sancho_non_orthogonal(setup): - SE = RecursiveSI(setup.HS, '-A') + SE = RecursiveSI(setup.HS, "-A") assert not np.allclose(SE.self_energy(0.1), SE.self_energy(0.1, bulk=True)) def test_sancho_broadening_matrix(setup): - SE = RecursiveSI(setup.HS, '-A') + SE = RecursiveSI(setup.HS, "-A") assert np.allclose(SE.broadening_matrix(0.1), SE.se2broadening(SE.self_energy(0.1))) def test_sancho_non_orthogonal_dtype(setup): - SE = RecursiveSI(setup.HS, '-A') + SE = RecursiveSI(setup.HS, "-A") s64 = SE.self_energy(0.1, dtype=np.complex64) s128 = SE.self_energy(0.1) assert s64.dtype == np.complex64 @@ -130,8 +138,8 @@ def test_sancho_non_orthogonal_dtype(setup): def test_sancho_lr(setup): - SL = RecursiveSI(setup.HS, '-A') - SR = RecursiveSI(setup.HS, '+A') + SL = RecursiveSI(setup.HS, "-A") + SR = RecursiveSI(setup.HS, "+A") E = 0.1 k = [0, 0.13, 0] @@ -152,7 +160,7 @@ def test_sancho_lr(setup): LB_SEL, LB_SER = SL.self_energy_lr(E, k, bulk=True) L_SE = SL.self_energy(E, k, bulk=True) - R_SE = SR.self_energy(E, k, bulk=True) + R_SE = SR.self_energy(E, k, bulk=True) assert not np.allclose(L_SE, R_SE) RB_SEL, RB_SER = SR.self_energy_lr(E, k, bulk=True) @@ -163,8 +171,8 @@ def test_sancho_lr(setup): def test_sancho_green(setup): - SL = RecursiveSI(setup.HS, '-A') - SR = RecursiveSI(setup.HS, '+A') + SL = RecursiveSI(setup.HS, "-A") + SR = RecursiveSI(setup.HS, "+A") E = 0.1 k = [0, 0.13, 0] @@ -182,12 +190,13 @@ def test_sancho_green(setup): def test_wideband_1(setup): SE = WideBandSE(10, 1e-2) assert SE.self_energy().shape == (10, 10) - assert np.allclose(np.diag(SE.self_energy()), -1j*1e-2) - assert np.allclose(np.diag(SE.self_energy(eta=1)), -1j*1.) - assert np.allclose(np.diag(SE.broadening_matrix(eta=1)), 2.) + assert np.allclose(np.diag(SE.self_energy()), -1j * 1e-2) + assert np.allclose(np.diag(SE.self_energy(eta=1)), -1j * 1.0) + assert np.allclose(np.diag(SE.broadening_matrix(eta=1)), 2.0) # ensure our custom function works! - assert np.allclose(SE.broadening_matrix(eta=1), - SE.se2broadening(SE.self_energy(eta=1))) + assert np.allclose( + SE.broadening_matrix(eta=1), SE.se2broadening(SE.self_energy(eta=1)) + ) @pytest.mark.parametrize("k_axes", [0, 1]) @@ -212,13 +221,15 @@ def test_real_space_HS(setup, k_axes, semi_axis, trs, bz, unfold): def test_real_space_H(setup, k_axes, semi_axis, trs, bz, unfold): if k_axes == semi_axis: return - RSE = RealSpaceSE(setup.H, semi_axis, k_axes, (unfold, unfold, 1), trs=trs, dk=100, bz=bz) + RSE = RealSpaceSE( + setup.H, semi_axis, k_axes, (unfold, unfold, 1), trs=trs, dk=100, bz=bz + ) RSE.green(0.1) RSE.self_energy(0.1) def test_real_space_H_3d(): - lattice = Lattice(1., nsc=[3] * 3) + lattice = Lattice(1.0, nsc=[3] * 3) H = Atom(Z=1, R=[1.001]) geom = Geometry([0] * 3, atoms=H, lattice=lattice) H = Hamiltonian(geom) @@ -235,7 +246,9 @@ def test_real_space_H_3d(): # Since there is only 2 repetitions along one direction we will have the full matrix # coupled! assert np.allclose(RSE.self_energy(0.1), RSE.self_energy(0.1, coupling=True)) - assert np.allclose(RSE.self_energy(0.1, bulk=True), RSE.self_energy(0.1, bulk=True, coupling=True)) + assert np.allclose( + RSE.self_energy(0.1, bulk=True), RSE.self_energy(0.1, bulk=True, coupling=True) + ) def test_real_space_H_dtype(setup): @@ -244,7 +257,7 @@ def test_real_space_H_dtype(setup): g128 = RSE.green(0.1, dtype=np.complex128) assert g64.dtype == np.complex64 assert g128.dtype == np.complex128 - assert np.allclose(g64, g128, atol=1.e-4) + assert np.allclose(g64, g128, atol=1.0e-4) s64 = RSE.self_energy(0.1, dtype=np.complex64) s128 = RSE.self_energy(0.1, dtype=np.complex128) @@ -376,7 +389,7 @@ def test_real_space_SE_spin_orbit(): @pytest.mark.parametrize("bulk", [True, False]) @pytest.mark.parametrize("coupling", [True, False]) def test_real_space_SI_HS(setup, k_axes, trs, bz, unfold, bulk, coupling): - semi = RecursiveSI(setup.HS, '-B') + semi = RecursiveSI(setup.HS, "-B") surf = setup.HS.tile(4, 1) surf.set_nsc(b=1) RSI = RealSpaceSI(semi, surf, k_axes, (unfold, 1, unfold)) @@ -385,7 +398,7 @@ def test_real_space_SI_HS(setup, k_axes, trs, bz, unfold, bulk, coupling): RSI.self_energy(0.1, bulk=bulk, coupling=coupling) -@pytest.mark.parametrize("semi_dir", ['-B', '+B']) +@pytest.mark.parametrize("semi_dir", ["-B", "+B"]) @pytest.mark.parametrize("k_axes", [0]) @pytest.mark.parametrize("trs", [True, False]) @pytest.mark.parametrize("bz", [None, BrillouinZone([1])]) @@ -393,7 +406,9 @@ def test_real_space_SI_HS(setup, k_axes, trs, bz, unfold, bulk, coupling): @pytest.mark.parametrize("bulk", [True, False]) @pytest.mark.parametrize("semi_bulk", [True, False]) @pytest.mark.parametrize("coupling", [True, False]) -def test_real_space_SI_H(setup, semi_dir, k_axes, trs, bz, unfold, bulk, semi_bulk, coupling): +def test_real_space_SI_H( + setup, semi_dir, k_axes, trs, bz, unfold, bulk, semi_bulk, coupling +): semi = RecursiveSI(setup.H, semi_dir) surf = setup.H.tile(4, 1) surf.set_nsc(b=1) @@ -404,7 +419,7 @@ def test_real_space_SI_H(setup, semi_dir, k_axes, trs, bz, unfold, bulk, semi_bu def test_real_space_SI_H_test(setup): - semi = RecursiveSI(setup.H, '-B') + semi = RecursiveSI(setup.H, "-B") surf = setup.H.tile(4, 1) surf.set_nsc(b=1) RSI = RealSpaceSI(semi, surf, 0, (3, 1, 3)) @@ -416,7 +431,7 @@ def test_real_space_SI_H_test(setup): def test_real_space_SI_H_k_trs(setup): - semi = RecursiveSI(setup.H, '-B') + semi = RecursiveSI(setup.H, "-B") surf = setup.H.tile(4, 1) surf.set_nsc(b=1) RSI = RealSpaceSI(semi, surf, 0, (3, 1, 3)) @@ -427,7 +442,7 @@ def test_real_space_SI_H_k_trs(setup): def test_real_space_SI_fail_semi_in_k(setup): - semi = RecursiveSI(setup.H, '-B') + semi = RecursiveSI(setup.H, "-B") surf = setup.H.tile(4, 1) surf.set_nsc(b=1) with pytest.raises(ValueError): @@ -435,14 +450,14 @@ def test_real_space_SI_fail_semi_in_k(setup): def test_real_space_SI_fail_surf_nsc(setup): - semi = RecursiveSI(setup.H, '-B') + semi = RecursiveSI(setup.H, "-B") surf = setup.H.tile(4, 1) with pytest.raises(ValueError): RSI = RealSpaceSI(semi, surf, 0, (2, 1, 1)) def test_real_space_SI_fail_k_no_nsc(setup): - semi = RecursiveSI(setup.H, '-B') + semi = RecursiveSI(setup.H, "-B") surf = setup.H.tile(4, 1) surf.set_nsc([1] * 3) with pytest.raises(ValueError): @@ -450,7 +465,7 @@ def test_real_space_SI_fail_k_no_nsc(setup): def test_real_space_SI_fail_unfold_in_semi(setup): - semi = RecursiveSI(setup.H, '-B') + semi = RecursiveSI(setup.H, "-B") surf = setup.H.tile(4, 1) surf.set_nsc(b=1) with pytest.raises(ValueError): @@ -464,10 +479,9 @@ def test_real_space_SI_spin_orbit(): on = [4, 0, 0, 0, 0, 0, 0, 0] off = [-1, 0, 0, 0, 0, 0, 0, 0] H.construct([(0.1, 1.1), (on, off)]) - semi = RecursiveSI(H, '-B') + semi = RecursiveSI(H, "-B") surf = H.tile(4, 1) surf.set_nsc(b=1) RS = RealSpaceSI(semi, surf, 0, (2, 1, 1), dk=1.5) RS.self_energy(0.1) RS.self_energy(0.1, coupling=True) - diff --git a/src/sisl/physics/tests/test_spin.py b/src/sisl/physics/tests/test_spin.py index d9ad1b2a5f..4f800ec05e 100644 --- a/src/sisl/physics/tests/test_spin.py +++ b/src/sisl/physics/tests/test_spin.py @@ -12,10 +12,20 @@ def test_spin1(): - for val in ['unpolarized', '', Spin.UNPOLARIZED, - 'polarized', 'p', Spin.POLARIZED, - 'non-collinear', 'nc', Spin.NONCOLINEAR, - 'spin-orbit', 'so', Spin.SPINORBIT]: + for val in [ + "unpolarized", + "", + Spin.UNPOLARIZED, + "polarized", + "p", + Spin.POLARIZED, + "non-collinear", + "nc", + Spin.NONCOLINEAR, + "spin-orbit", + "so", + Spin.SPINORBIT, + ]: s = Spin(val) str(s) s1 = s.copy() @@ -24,9 +34,9 @@ def test_spin1(): def test_spin2(): s1 = Spin() - s2 = Spin('p') - s3 = Spin('nc') - s4 = Spin('so') + s2 = Spin("p") + s3 = Spin("nc") + s4 = Spin("so") assert s1.kind == Spin.UNPOLARIZED assert s2.kind == Spin.POLARIZED @@ -77,7 +87,7 @@ def test_spin2(): def test_spin3(): with pytest.raises(ValueError): - s = Spin('satoehus') + s = Spin("satoehus") def test_spin4(): @@ -168,31 +178,33 @@ def test_pauli(): S = Spin() # Create a fictituous wave-function - sq2 = 2 ** .5 - W = np.array([ - [1/sq2, 1/sq2], # M_x = 1 - [1/sq2, -1/sq2], # M_x = -1 - [0.5 + 0.5j, 0.5 + 0.5j], # M_x = 1 - [0.5 - 0.5j, -0.5 + 0.5j], # M_x = -1 - [1/sq2, 1j/sq2], # M_y = 1 - [1/sq2, -1j/sq2], # M_y = -1 - [0.5 - 0.5j, 0.5 + 0.5j], # M_y = 1 - [0.5 + 0.5j, 0.5 - 0.5j], # M_y = -1 - [1, 0], # M_z = 1 - [0, 1], # M_z = -1 - ]) + sq2 = 2**0.5 + W = np.array( + [ + [1 / sq2, 1 / sq2], # M_x = 1 + [1 / sq2, -1 / sq2], # M_x = -1 + [0.5 + 0.5j, 0.5 + 0.5j], # M_x = 1 + [0.5 - 0.5j, -0.5 + 0.5j], # M_x = -1 + [1 / sq2, 1j / sq2], # M_y = 1 + [1 / sq2, -1j / sq2], # M_y = -1 + [0.5 - 0.5j, 0.5 + 0.5j], # M_y = 1 + [0.5 + 0.5j, 0.5 - 0.5j], # M_y = -1 + [1, 0], # M_z = 1 + [0, 1], # M_z = -1 + ] + ) x = np.array([1, -1, 1, -1, 0, 0, 0, 0, 0, 0]) - assert np.allclose(x, (np.conj(W)*S.X.dot(W.T).T).sum(1).real) + assert np.allclose(x, (np.conj(W) * S.X.dot(W.T).T).sum(1).real) y = np.array([0, 0, 0, 0, 1, -1, 1, -1, 0, 0]) - assert np.allclose(y, (np.conj(W)*np.dot(S.Y, W.T).T).sum(1).real) + assert np.allclose(y, (np.conj(W) * np.dot(S.Y, W.T).T).sum(1).real) z = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, -1]) - assert np.allclose(z, (np.conj(W)*np.dot(S.Z, W.T).T).sum(1).real) + assert np.allclose(z, (np.conj(W) * np.dot(S.Z, W.T).T).sum(1).real) def test_pickle(): import pickle as p - S = Spin('nc') + S = Spin("nc") n = p.dumps(S) s = p.loads(n) assert S == s diff --git a/src/sisl/physics/tests/test_state.py b/src/sisl/physics/tests/test_state.py index 6b29fe9581..5091d22423 100644 --- a/src/sisl/physics/tests/test_state.py +++ b/src/sisl/physics/tests/test_state.py @@ -19,6 +19,7 @@ def ortho_matrix(n, m=None): m = n max_nm = max(n, m) from scipy.linalg import qr + H = np.random.randn(max_nm, max_nm) + 1j * np.random.randn(max_nm, max_nm) Q, _ = qr(H) M = Q.dot(np.conjugate(Q.T)) @@ -37,9 +38,9 @@ def test_coefficient_creation_simple(): c = Coefficient(ar(6)) str(c) assert len(c) == 6 - assert c.shape == (6, ) + assert c.shape == (6,) assert c.dtype == np.float64 - assert c.dkind == 'f' + assert c.dkind == "f" assert len(c.sub(1)) == 1 assert np.allclose(c.sub(1).c, 1) assert len(c.sub([1, 4])) == 2 @@ -48,16 +49,16 @@ def test_coefficient_creation_simple(): def test_coefficient_creation_info(): - c = Coefficient(ar(6), geom.graphene(), k='HELLO') + c = Coefficient(ar(6), geom.graphene(), k="HELLO") assert np.allclose(c.parent.xyz, geom.graphene().xyz) - assert c.info['k'] == 'HELLO' + assert c.info["k"] == "HELLO" def test_coefficient_copy(): - c = Coefficient(ar(6), geom.graphene(), k='HELLO', test='test') + c = Coefficient(ar(6), geom.graphene(), k="HELLO", test="test") cc = c.copy() - assert cc.info['k'] == 'HELLO' - assert cc.info['test'] == 'test' + assert cc.info["k"] == "HELLO" + assert cc.info["test"] == "test" def test_coefficient_sub(): @@ -69,8 +70,9 @@ def test_coefficient_sub(): for i, sub in enumerate(state): assert len(sub) == 1 - assert np.allclose(state.sub(np.array([False, True, False, True])).c, - state.sub([1, 3]).c) + assert np.allclose( + state.sub(np.array([False, True, False, True])).c, state.sub([1, 3]).c + ) sub = state.sub(np.array([False, True, False, True])) state.sub([1, 3], inplace=True) @@ -107,9 +109,9 @@ def test_state_repr(): def test_state_dkind(): state = State(ar(6)) - assert state.dkind == 'f' + assert state.dkind == "f" state = State(ar(6).astype(np.complex128)) - assert state.dkind == 'c' + assert state.dkind == "c" def test_state_norm(): @@ -134,10 +136,11 @@ def test_state_sub(): assert sub.norm()[0] == norm[i] for i, sub in enumerate(state.iter(True)): - assert (sub ** 2).sum() == norm2[i] + assert (sub**2).sum() == norm2[i] - assert np.allclose(state.sub(np.array([False, True, False, True])).state, - state.sub([1, 3]).state) + assert np.allclose( + state.sub(np.array([False, True, False, True])).state, state.sub([1, 3]).state + ) sub = state.sub(np.array([False, True, False, True])) state.sub([1, 3], inplace=True) @@ -209,8 +212,8 @@ def test_state_phase_all(): state = np.random.rand(10, 10) + 1j * np.random.rand(10, 10) state1 = State(state) state2 = State(-state) - ph1 = state1.phase('all') - ph2 = state2.phase('all') + ph1 = state1.phase("all") + ph2 = state2.phase("all") assert np.allclose(ph1, ph2 + np.pi) @@ -268,7 +271,7 @@ def test_state_align_norm2(): def test_state_rotate(): - state = State([[1+1.j, 1.], [0.1-0.1j, 0.1]]) + state = State([[1 + 1.0j, 1.0], [0.1 - 0.1j, 0.1]]) # Angles are 45 and -45 s = state.copy() @@ -277,7 +280,7 @@ def test_state_rotate(): assert -np.pi / 4 == pytest.approx(np.angle(s.state[1, 0])) assert 0 == pytest.approx(np.angle(s.state[1, 1])) - s.rotate() # individual false + s.rotate() # individual false assert 0 == pytest.approx(np.angle(s.state[0, 0])) assert -np.pi / 4 == pytest.approx(np.angle(s.state[0, 1])) assert -np.pi / 2 == pytest.approx(np.angle(s.state[1, 0])) diff --git a/src/sisl/quaternion.py b/src/sisl/quaternion.py index 73d6885287..a88b508a07 100644 --- a/src/sisl/quaternion.py +++ b/src/sisl/quaternion.py @@ -7,7 +7,7 @@ from ._internal import set_module -__all__ = ['Quaternion'] +__all__ = ["Quaternion"] @set_module("sisl") @@ -16,8 +16,8 @@ class Quaternion: Quaternion object to enable easy rotational quantities. """ - def __init__(self, angle=0., v=None, rad=False): - """ Create quaternion object with angle and vector """ + def __init__(self, angle=0.0, v=None, rad=False): + """Create quaternion object with angle and vector""" if rad: half = angle / 2 else: @@ -29,38 +29,38 @@ def __init__(self, angle=0., v=None, rad=False): self._v[1:] = np.array(v[:3], np.float64) * m.sin(half) def copy(self): - """ Return deepcopy of itself """ + """Return deepcopy of itself""" q = Quaternion() q._v = np.copy(self._v) return q def conj(self): - """ Returns the conjugate of it-self """ + """Returns the conjugate of it-self""" q = self.copy() q._v[1:] *= -1 return q def norm(self): - """ Returns the norm of this quaternion """ + """Returns the norm of this quaternion""" return np.sqrt(np.sum(self._v**2)) @property def degree(self): - """ Returns the angle associated with this quaternion (in degree)""" - return m.acos(self._v[0]) * 360. / m.pi + """Returns the angle associated with this quaternion (in degree)""" + return m.acos(self._v[0]) * 360.0 / m.pi @property def radian(self): - """ Returns the angle associated with this quaternion (in radians)""" - return m.acos(self._v[0]) * 2. + """Returns the angle associated with this quaternion (in radians)""" + return m.acos(self._v[0]) * 2.0 angle = radian def rotate(self, v): - """ Rotates 3-dimensional vector ``v`` with the associated quaternion """ + """Rotates 3-dimensional vector ``v`` with the associated quaternion""" if len(v.shape) == 1: q = self.copy() - q._v[0] = 1. + q._v[0] = 1.0 q._v[1:] = v[:] q = self * q * self.conj() return q._v[1:] @@ -78,29 +78,26 @@ def rotate(self, v): f[3, :] = v1[0] * v[:, 2] + v1[1] * v[:, 1] - v1[2] * v[:, 0] + v1[3] # Create actual rotated array nv = np.empty(v.shape, v.dtype) - nv[:, 0] = f[0, :] * v2[1] + f[1, :] * \ - v2[0] + f[2, :] * v2[3] - f[3, :] * v2[2] - nv[:, 1] = f[0, :] * v2[2] - f[1, :] * \ - v2[3] + f[2, :] * v2[0] + f[3, :] * v2[1] - nv[:, 2] = f[0, :] * v2[3] + f[1, :] * \ - v2[2] - f[2, :] * v2[1] + f[3, :] * v2[0] + nv[:, 0] = f[0, :] * v2[1] + f[1, :] * v2[0] + f[2, :] * v2[3] - f[3, :] * v2[2] + nv[:, 1] = f[0, :] * v2[2] - f[1, :] * v2[3] + f[2, :] * v2[0] + f[3, :] * v2[1] + nv[:, 2] = f[0, :] * v2[3] + f[1, :] * v2[2] - f[2, :] * v2[1] + f[3, :] * v2[0] del f # re-create shape nv.shape = s return nv def __eq__(self, other): - """ Returns whether two Quaternions are equal """ + """Returns whether two Quaternions are equal""" return np.allclose(self._v, other._v) def __neg__(self): - """ Returns the negative quaternion """ + """Returns the negative quaternion""" q = self.copy() q._v = -q._v return q def __add__(self, other): - """ Returns the added quantity """ + """Returns the added quantity""" q = self.copy() if isinstance(other, Quaternion): q._v += other._v @@ -109,7 +106,7 @@ def __add__(self, other): return q def __sub__(self, other): - """ Returns the subtracted quantity """ + """Returns the subtracted quantity""" q = self.copy() if isinstance(other, Quaternion): q._v -= other._v @@ -118,33 +115,31 @@ def __sub__(self, other): return q def __mul__(self, other): - """ Multiplies with another instance or scalar """ + """Multiplies with another instance or scalar""" q = self.copy() if isinstance(other, Quaternion): v1 = np.copy(self._v) v2 = other._v - q._v[0] = v1[0] * v2[0] - v1[1] * \ - v2[1] - v1[2] * v2[2] - v1[3] * v2[3] - q._v[1] = v1[0] * v2[1] + v1[1] * \ - v2[0] + v1[2] * v2[3] - v1[3] * v2[2] - q._v[2] = v1[0] * v2[2] - v1[1] * \ - v2[3] + v1[2] * v2[0] + v1[3] * v2[1] - q._v[3] = v1[0] * v2[3] + v1[1] * \ - v2[2] - v1[2] * v2[1] + v1[3] * v2[0] + q._v[0] = v1[0] * v2[0] - v1[1] * v2[1] - v1[2] * v2[2] - v1[3] * v2[3] + q._v[1] = v1[0] * v2[1] + v1[1] * v2[0] + v1[2] * v2[3] - v1[3] * v2[2] + q._v[2] = v1[0] * v2[2] - v1[1] * v2[3] + v1[2] * v2[0] + v1[3] * v2[1] + q._v[3] = v1[0] * v2[3] + v1[1] * v2[2] - v1[2] * v2[1] + v1[3] * v2[0] else: q._v *= other return q def __div__(self, other): - """ Divides with a scalar """ + """Divides with a scalar""" if isinstance(other, Quaternion): - raise ValueError("Do not know how to divide a quaternion " + - "with a quaternion.") - return self * (1. / other) + raise ValueError( + "Do not know how to divide a quaternion " + "with a quaternion." + ) + return self * (1.0 / other) + __truediv__ = __div__ def __iadd__(self, other): - """ In-place addition """ + """In-place addition""" if isinstance(other, Quaternion): self._v += other._v else: @@ -152,7 +147,7 @@ def __iadd__(self, other): return self def __isub__(self, other): - """ In-place subtraction """ + """In-place subtraction""" if isinstance(other, Quaternion): self._v -= other._v else: @@ -161,28 +156,26 @@ def __isub__(self, other): # The in-place operators def __imul__(self, other): - """ In-place multiplication """ + """In-place multiplication""" if isinstance(other, Quaternion): v1 = np.copy(self._v) v2 = other._v - self._v[0] = v1[0] * v2[0] - v1[1] * \ - v2[1] - v1[2] * v2[2] - v1[3] * v2[3] - self._v[1] = v1[0] * v2[1] + v1[1] * \ - v2[0] + v1[2] * v2[3] - v1[3] * v2[2] - self._v[2] = v1[0] * v2[2] - v1[1] * \ - v2[3] + v1[2] * v2[0] + v1[3] * v2[1] - self._v[3] = v1[0] * v2[3] + v1[1] * \ - v2[2] - v1[2] * v2[1] + v1[3] * v2[0] + self._v[0] = v1[0] * v2[0] - v1[1] * v2[1] - v1[2] * v2[2] - v1[3] * v2[3] + self._v[1] = v1[0] * v2[1] + v1[1] * v2[0] + v1[2] * v2[3] - v1[3] * v2[2] + self._v[2] = v1[0] * v2[2] - v1[1] * v2[3] + v1[2] * v2[0] + v1[3] * v2[1] + self._v[3] = v1[0] * v2[3] + v1[1] * v2[2] - v1[2] * v2[1] + v1[3] * v2[0] else: self._v *= other return self def __idiv__(self, other): - """ In-place division """ + """In-place division""" if isinstance(other, Quaternion): - raise ValueError("Do not know how to divide a quaternion " + - "with a quaternion.") + raise ValueError( + "Do not know how to divide a quaternion " + "with a quaternion." + ) # use imul self._v /= other return self + __itruediv__ = __idiv__ diff --git a/src/sisl/selector.py b/src/sisl/selector.py index 6d44af9284..fab46dc1ee 100644 --- a/src/sisl/selector.py +++ b/src/sisl/selector.py @@ -56,12 +56,12 @@ from ._internal import set_module from .messages import warn -__all__ = ['Selector', 'TimeSelector'] +__all__ = ["Selector", "TimeSelector"] @set_module("sisl") class Selector: - r""" Base class for implementing a selector of class routines + r"""Base class for implementing a selector of class routines This class should contain a list of routines and may then be used to always return the best performant routine. The *performance* @@ -85,10 +85,9 @@ class Selector: runned routines. """ - __slots__ = ['_routines', '_metric', '_best', '_ordered'] + __slots__ = ["_routines", "_metric", "_best", "_ordered"] def __init__(self, routines=None, ordered=False): - # Copy the routines to the list if routines is None: self._routines = [] @@ -120,21 +119,21 @@ def ordered(self): return self._ordered def __len__(self): - """ Number of routines that it can select from """ + """Number of routines that it can select from""" return len(self.routines) def __str__(self): - """ A representation of the current selector state """ - s = self.__class__.__name__ + '{{n={0}, \n'.format(len(self)) + """A representation of the current selector state""" + s = self.__class__.__name__ + "{{n={0}, \n".format(len(self)) for r, p in zip(self.routines, self.metric): if p is None: - s += f' {{{r.__name__}: }},\n' + s += f" {{{r.__name__}: }},\n" else: - s += f' {{{r.__name__}: {p}}},\n' - return s + '}' + s += f" {{{r.__name__}: {p}}},\n" + return s + "}" def prepend(self, routine): - """ Prepends a new routine to the selector + """Prepends a new routine to the selector Parameters ---------- @@ -147,7 +146,7 @@ def prepend(self, routine): self._best = None def append(self, routine): - """ Prepends a new routine to the selector + """Prepends a new routine to the selector Parameters ---------- @@ -169,12 +168,12 @@ def append(self, routine): self._best = None def reset(self): - """ Reset the metric table to redo the performance checks """ + """Reset the metric table to redo the performance checks""" self._metric = [None] * len(self._metric) self._best = None def select_best(self, routine=None): - """ Update the `best` routine, if applicable + """Update the `best` routine, if applicable Update the selector to choose the best method. If not all routines have been carried through, then @@ -195,10 +194,9 @@ def select_best(self, routine=None): best method """ if routine is None: - # Try and select the routine based on the internal runned # metric specifiers - selected, metric = -1, 0. + selected, metric = -1, 0.0 for i, v in enumerate(self.metric): if v is None: # Quick return if we are not done @@ -227,14 +225,16 @@ def select_best(self, routine=None): self._best = r break if self.best is None: - warn(self.__class__.__name__ + ' selection of ' - 'optimal routine is not in the list of available ' - 'routines. Will not select a routine.') + warn( + self.__class__.__name__ + " selection of " + "optimal routine is not in the list of available " + "routines. Will not select a routine." + ) else: self._best = routine def next(self): - """ Choose the next routine that requires metric analysis + """Choose the next routine that requires metric analysis Returns ------- @@ -252,13 +252,13 @@ def next(self): return -1, self._best def __call__(self, *args, **kwargs): - """ Call the function that optimizes the run-time the most + """Call the function that optimizes the run-time the most The first argument *must* be an object (`self`) while all remaining arguments are transferred to the routine calls """ if not self.best is None: - return self.best(*args, **kwargs) # pylint: disable=E1102 + return self.best(*args, **kwargs) # pylint: disable=E1102 # Figure out if we have the metric for all the routines idx, routine = self.next() @@ -280,7 +280,7 @@ def __call__(self, *args, **kwargs): return returns def start(self): - """ Start the metric profiler + """Start the metric profiler This routine should return an initial state value. The difference between `stop() - start()` should yield a @@ -293,7 +293,7 @@ def start(self): raise NotImplementedError def stop(self, start): - """ Stop the metric profiler + """Stop the metric profiler This routine returns the actual metric for the method. Its input is what `start` returns and it may be of any @@ -312,17 +312,17 @@ def stop(self, start): @set_module("sisl") class TimeSelector(Selector): - """ Routine metric selector based on timings for the routines """ + """Routine metric selector based on timings for the routines""" def start(self): - """ Start the timing routine """ + """Start the timing routine""" return time.time() def stop(self, start): - """ Stop the timing routine + """Stop the timing routine Returns ------- inv_time : metric for time """ - return 1. / (time.time() - start) + return 1.0 / (time.time() - start) diff --git a/src/sisl/shape/__init__.py b/src/sisl/shape/__init__.py index 47726a1245..f7a67d6e8d 100644 --- a/src/sisl/shape/__init__.py +++ b/src/sisl/shape/__init__.py @@ -42,4 +42,4 @@ from .ellipsoid import * from .prism4 import * -__all__ = [s for s in dir() if not s.startswith('_')] +__all__ = [s for s in dir() if not s.startswith("_")] diff --git a/src/sisl/shape/_cylinder.py b/src/sisl/shape/_cylinder.py index 5c5b6acbd8..4734948752 100644 --- a/src/sisl/shape/_cylinder.py +++ b/src/sisl/shape/_cylinder.py @@ -18,7 +18,7 @@ @set_module("sisl.shape") class EllipticalCylinder(PureShape): - r""" 3D elliptical cylinder + r"""3D elliptical cylinder Parameters ---------- @@ -47,7 +47,7 @@ class EllipticalCylinder(PureShape): >>> shape.within([1.4, 0, 1.1]) False """ - __slots__ = ('_v', '_nh', '_iv', '_h') + __slots__ = ("_v", "_nh", "_iv", "_h") def __init__(self, v, h: float, axes=(0, 1), center=None): super().__init__(center) @@ -59,13 +59,17 @@ def __init__(self, v, h: float, axes=(0, 1), center=None): elif v.size == 6: vv[:, :] = v else: - raise ValueError(f"{self.__class__.__name__} expected 'v' to be of size (1,), (2,) or (2, 3), got {v.shape}") + raise ValueError( + f"{self.__class__.__name__} expected 'v' to be of size (1,), (2,) or (2, 3), got {v.shape}" + ) # If the vectors are not orthogonal, orthogonalize them and issue a warning vv_ortho = np.fabs(vv @ vv.T - np.diag(fnorm2(vv))) if vv_ortho.sum() > 1e-9: - warn(f"{self.__class__.__name__ } principal vectors are not orthogonal. " - "sisl orthogonalizes the vectors (retaining 1st vector)!") + warn( + f"{self.__class__.__name__ } principal vectors are not orthogonal. " + "sisl orthogonalizes the vectors (retaining 1st vector)!" + ) vv[1] = orthogonalize(vv[0], vv[1]) @@ -89,11 +93,11 @@ def copy(self): return self.__class__(self.radial_vector, self.height, self.center) def volume(self): - """ Return the volume of the shape """ + """Return the volume of the shape""" return pi * np.product(self.radius) * self.height def scale(self, scale: float): - """ Create a new shape with all dimensions scaled according to `scale` + """Create a new shape with all dimensions scaled according to `scale` Parameters ---------- @@ -111,7 +115,7 @@ def scale(self, scale: float): return self.__class__(v, h, self.center) def expand(self, radius): - """ Expand elliptical cylinder by a constant value along each vector and height + """Expand elliptical cylinder by a constant value along each vector and height Parameters ---------- @@ -128,11 +132,13 @@ def expand(self, radius): v1 = expand(self._v[1], radius[1]) h = self.height + radius[2] else: - raise ValueError(f"{self.__class__.__name__}.expand requires the radius to be either (1,) or (3,)") + raise ValueError( + f"{self.__class__.__name__}.expand requires the radius to be either (1,) or (3,)" + ) return self.__class__([v0, v1], h, self.center) - def within_index(self, other, tol=1.e-8): - r""" Return indices of the points that are within the shape + def within_index(self, other, tol=1.0e-8): + r"""Return indices of the points that are within the shape Parameters ---------- @@ -150,30 +156,30 @@ def within_index(self, other, tol=1.e-8): # Get indices where we should do the more # expensive exact check of being inside shape # I.e. this reduces the search space to the box - return indices_in_cylinder(tmp, 1. + tol, 1. + tol) + return indices_in_cylinder(tmp, 1.0 + tol, 1.0 + tol) @property def height(self): - """ Height of the cylinder """ + """Height of the cylinder""" return self._h @property def radius(self): - """ Radius of the ellipse base vectors """ + """Radius of the ellipse base vectors""" return fnorm(self._v) @property def radial_vector(self): - """ The radial vectors """ + """The radial vectors""" return self._v @property def height_vector(self): - """ The height vector """ + """The height vector""" return self._nh def toSphere(self): - """ Convert to a sphere """ + """Convert to a sphere""" from .ellipsoid import Sphere # figure out the distance from the center to the edge (along longest radius) @@ -185,8 +191,9 @@ def toSphere(self): return Sphere(r, self.center.copy()) def toCuboid(self): - """ Return a cuboid with side lengths equal to the diameter of each ellipsoid vectors """ + """Return a cuboid with side lengths equal to the diameter of each ellipsoid vectors""" from .prism4 import Cuboid + return Cuboid([self._v[0], self._v[1], self._nh], self.center) @@ -196,6 +203,7 @@ def toCuboid(self): class EllipticalCylinderToSphere(ShapeToDispatch): def dispatch(self, *args, **kwargs): from .ellipsoid import Sphere + shape = self._get_object() # figure out the distance from the center to the edge (along longest radius) h = shape.height / 2 @@ -205,15 +213,18 @@ def dispatch(self, *args, **kwargs): # Rescale each vector return Sphere(r, shape.center.copy()) + to_dispatch.register("Sphere", EllipticalCylinderToSphere) class EllipticalCylinderToCuboid(ShapeToDispatch): def dispatch(self, *args, **kwargs): from .prism4 import Cuboid + shape = self._get_object() return Cuboid([shape._v[0], shape._v[1], shape._nh], shape.center) + to_dispatch.register("Cuboid", EllipticalCylinderToCuboid) del to_dispatch diff --git a/src/sisl/shape/base.py b/src/sisl/shape/base.py index 28af5ae86a..99fbd4660f 100644 --- a/src/sisl/shape/base.py +++ b/src/sisl/shape/base.py @@ -12,17 +12,28 @@ from sisl.messages import deprecation from sisl.utils.mathematics import fnorm -__all__ = ["Shape", "PureShape", "NullShape", "ShapeToDispatch", - "CompositeShape", "OrShape", "XOrShape", "AndShape", "SubShape"] +__all__ = [ + "Shape", + "PureShape", + "NullShape", + "ShapeToDispatch", + "CompositeShape", + "OrShape", + "XOrShape", + "AndShape", + "SubShape", +] @set_module("sisl.shape") -class Shape(_Dispatchs, - dispatchs=[("to", ClassDispatcher("to", - type_dispatcher=None, - obj_getattr="error"))], - when_subclassing="copy"): - """ Baseclass for all shapes. Logical operations are implemented on this class. +class Shape( + _Dispatchs, + dispatchs=[ + ("to", ClassDispatcher("to", type_dispatcher=None, obj_getattr="error")) + ], + when_subclassing="copy", +): + """Baseclass for all shapes. Logical operations are implemented on this class. **This class must be sub classed.** @@ -65,7 +76,8 @@ class Shape(_Dispatchs, center : (3,) the center of the shape """ - __slots__ = ('_center', ) + + __slots__ = ("_center",) def __init__(self, center=(0, 0, 0)): if center is None: @@ -74,35 +86,45 @@ def __init__(self, center=(0, 0, 0)): @property def center(self): - """ The geometric center of the shape """ + """The geometric center of the shape""" return self._center @center.setter def center(self, center): - """ Set the geometric center of the shape """ + """Set the geometric center of the shape""" self._center[:] = center def scale(self, scale): - """ Return a new Shape with a scaled size """ - raise NotImplementedError(f"{self.__class__.__name__}.scale has not been implemented") - - @deprecation("toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15") + """Return a new Shape with a scaled size""" + raise NotImplementedError( + f"{self.__class__.__name__}.scale has not been implemented" + ) + + @deprecation( + "toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15" + ) def toSphere(self, *args, **kwargs): - """ Create a sphere which is surely encompassing the *full* shape """ - raise NotImplementedError(f"{self.__class__.__name__}.toSphere has not been implemented") - - @deprecation("toEllipsoid is deprecated, please use shape.to.Ellipsoid(...) instead.", "0.15") + """Create a sphere which is surely encompassing the *full* shape""" + raise NotImplementedError( + f"{self.__class__.__name__}.toSphere has not been implemented" + ) + + @deprecation( + "toEllipsoid is deprecated, please use shape.to.Ellipsoid(...) instead.", "0.15" + ) def toEllipsoid(self, *args, **kwargs): - """ Create an ellipsoid which is surely encompassing the *full* shape """ + """Create an ellipsoid which is surely encompassing the *full* shape""" return self.to.Sphere().to.Ellipsoid(*args, **kwargs) - @deprecation("toCuboid is deprecated, please use shape.to.Cuboid(...) instead.", "0.15") + @deprecation( + "toCuboid is deprecated, please use shape.to.Cuboid(...) instead.", "0.15" + ) def toCuboid(self, *args, **kwargs): - """ Create a cuboid which is surely encompassing the *full* shape """ + """Create a cuboid which is surely encompassing the *full* shape""" return self.to.Ellipsoid().to.Cuboid(*args, **kwargs) def within(self, other, *args, **kwargs): - """ Return ``True`` if `other` is fully within `self` + """Return ``True`` if `other` is fully within `self` If `other` is an array, an array will be returned for each of these. @@ -128,11 +150,13 @@ def within(self, other, *args, **kwargs): return within def within_index(self, other, *args, **kwargs): - """ Return indices of the elements of `other` that are within the shape """ - raise NotImplementedError(f"{self.__class__.__name__}.within_index has not been implemented") + """Return indices of the elements of `other` that are within the shape""" + raise NotImplementedError( + f"{self.__class__.__name__}.within_index has not been implemented" + ) def __contains__(self, other): - """ Checks whether all of `other` is within the shape """ + """Checks whether all of `other` is within the shape""" return np.all(self.within(other)) def __str__(self): @@ -158,14 +182,18 @@ def __xor__(self, other): to_dispatch = Shape.to + # Add dispatcher systems class ShapeToDispatch(AbstractDispatch): - """ Base dispatcher from class passing from a Shape class """ + """Base dispatcher from class passing from a Shape class""" + class ToEllipsoidDispatch(ShapeToDispatch): def dispatch(self, *args, **kwargs): shape = self._get_object() return shape.to.Sphere(*args, **kwargs).to.Ellipsoid() + + to_dispatch.register("Ellipsoid", ToEllipsoidDispatch) @@ -173,12 +201,14 @@ class ToCuboidDispatch(ShapeToDispatch): def dispatch(self, *args, **kwargs): shape = self._get_object() return shape.to.Ellipsoid(*args, **kwargs).to.Cuboid() + + to_dispatch.register("Cuboid", ToCuboidDispatch) @set_module("sisl.shape") class CompositeShape(Shape): - """ A composite shape consisting of two shapes, an abstract class + """A composite shape consisting of two shapes, an abstract class This should take 2 shapes as arguments. @@ -189,27 +219,26 @@ class CompositeShape(Shape): B : Shape the right hand side of the set operation """ - __slots__ = ('A', 'B') + + __slots__ = ("A", "B") def __init__(self, A, B): self.A = A.copy() self.B = B.copy() - def __init_subclass__(cls, /, - composite_name: str, - **kwargs): + def __init_subclass__(cls, /, composite_name: str, **kwargs): super().__init_subclass__(**kwargs) cls.__slots__ = () cls.__str__ = _composite_name(composite_name) @property def center(self): - """ Average center of composite shapes """ + """Average center of composite shapes""" return (self.A.center + self.B.center) * 0.5 @staticmethod def volume(): - """ Volume of a composite shape is current undefined, so a negative number is returned (may change) """ + """Volume of a composite shape is current undefined, so a negative number is returned (may change)""" # The volume for these set operators cannot easily be defined, so # we should rather not do anything about it. # TODO we could *estimate* the volume by doing @@ -219,11 +248,13 @@ def volume(): # and calculate fractional volume # This is very inaccurate, but would probably be # good enough. - return -1. + return -1.0 - @deprecation("toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15") + @deprecation( + "toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15" + ) def toSphere(self, *args, **kwargs): - """ Create a sphere which is surely encompassing the *full* shape """ + """Create a sphere which is surely encompassing the *full* shape""" return self.to.Sphere(*args, **kwargs) def scale(self, scale): @@ -235,8 +266,9 @@ def copy(self): class ToSphereDispatch(ShapeToDispatch): def dispatch(self, center=None): - """ Create a sphere which is surely encompassing the *full* shape """ + """Create a sphere which is surely encompassing the *full* shape""" from .ellipsoid import Sphere + shape = self._get_object() # Retrieve spheres @@ -255,34 +287,39 @@ def dispatch(self, center=None): return Sphere(max(A, B), center) + CompositeShape.to.register("Sphere", ToSphereDispatch) + class ToEllipsoidDispatcher(ShapeToDispatch): def dispatch(self, *args, center=None, **kwargs): from .ellipsoid import Ellipsoid + shape = self._get_object() return shape.to.Sphere(center=center).to.Ellipsoid() + CompositeShape.to.register("Ellipsoid", ToEllipsoidDispatch) def _composite_name(sep): def _str(self): if isinstance(self.A, CompositeShape): - A = "({})".format(str(self.A).replace('\n', '\n ')) + A = "({})".format(str(self.A).replace("\n", "\n ")) else: - A = "{}".format(str(self.A).replace('\n', '\n ')) + A = "{}".format(str(self.A).replace("\n", "\n ")) if isinstance(self.B, CompositeShape): - B = "({})".format(str(self.B).replace('\n', '\n ')) + B = "({})".format(str(self.B).replace("\n", "\n ")) else: - B = "{}".format(str(self.B).replace('\n', '\n ')) + B = "{}".format(str(self.B).replace("\n", "\n ")) return f"{self.__class__.__name__}{{{A} {sep} {B}}}" + return _str @set_module("sisl.shape") class OrShape(CompositeShape, composite_name="|"): - """ Boolean ``A | B`` shape """ + """Boolean ``A | B`` shape""" def within_index(self, *args, **kwargs): A = self.A.within_index(*args, **kwargs) @@ -292,7 +329,7 @@ def within_index(self, *args, **kwargs): @set_module("sisl.shape") class XOrShape(CompositeShape, composite_name="^"): - """ Boolean ``A ^ B`` shape """ + """Boolean ``A ^ B`` shape""" def within_index(self, *args, **kwargs): A = self.A.within_index(*args, **kwargs) @@ -302,7 +339,7 @@ def within_index(self, *args, **kwargs): @set_module("sisl.shape") class SubShape(CompositeShape, composite_name="-"): - """ Boolean ``A - B`` shape """ + """Boolean ``A - B`` shape""" def within_index(self, *args, **kwargs): A = self.A.within_index(*args, **kwargs) @@ -312,11 +349,13 @@ def within_index(self, *args, **kwargs): @set_module("sisl.shape") class AndShape(CompositeShape, composite_name="&"): - """ Boolean ``A & B`` shape """ + """Boolean ``A & B`` shape""" - @deprecation("toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15") + @deprecation( + "toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15" + ) def toSphere(self, *args, **kwargs): - """ Create a sphere which is surely encompassing the *full* shape """ + """Create a sphere which is surely encompassing the *full* shape""" return self.to.Sphere(*args, **kwargs) def within_index(self, *args, **kwargs): @@ -324,10 +363,12 @@ def within_index(self, *args, **kwargs): B = self.B.within_index(*args, **kwargs) return np.intersect1d(A, B, assume_unique=True) + class AndToSphereDispatch(ShapeToDispatch): def dispatch(self, center=None): - """ Create a sphere which is surely encompassing the *full* shape """ + """Create a sphere which is surely encompassing the *full* shape""" from .ellipsoid import Sphere + shape = self._obj # Retrieve spheres @@ -353,7 +394,7 @@ def dispatch(self, center=None): elif dist <= (Ar + Br): # We can reduce the sphere drastically because only the overlapping region is important # i_r defines the intersection radius, search for Sphere-Sphere Intersection - dx = (dist ** 2 - Br ** 2 + Ar ** 2) / (2 * dist) + dx = (dist**2 - Br**2 + Ar**2) / (2 * dist) if dx > dist: # the intersection is placed after the radius of B @@ -362,7 +403,9 @@ def dispatch(self, center=None): elif dx < 0: return A - i_r = msqrt(4 * (dist * Ar) ** 2 - (dist ** 2 - Br ** 2 + Ar ** 2) ** 2) / (2 * dist) + i_r = msqrt(4 * (dist * Ar) ** 2 - (dist**2 - Br**2 + Ar**2) ** 2) / ( + 2 * dist + ) # Now we simply need to find the dx point along the vector Bc - Ac # Then we can easily calculate the point from A @@ -382,12 +425,13 @@ def dispatch(self, center=None): return Sphere(max(A, B), center) + AndShape.to.register("Sphere", AndToSphereDispatch) @set_module("sisl.shape") class PureShape(Shape): - """ Extension of the `Shape` class for additional well defined shapes + """Extension of the `Shape` class for additional well defined shapes This shape should be used when subclassing shapes where the volume of the shape is *exactly* known. @@ -395,19 +439,24 @@ class PureShape(Shape): `volume` return the volume of the shape. """ + __slots__ = () def volume(self, *args, **kwargs): - raise NotImplementedError(f"{self.__class__.__name__}.volume has not been implemented") + raise NotImplementedError( + f"{self.__class__.__name__}.volume has not been implemented" + ) def expand(self, _): - """ Expand the shape by a constant value """ - raise NotImplementedError(f"{self.__class__.__name__}.expand has not been implemented") + """Expand the shape by a constant value""" + raise NotImplementedError( + f"{self.__class__.__name__}.expand has not been implemented" + ) @set_module("sisl.shape") class NullShape(PureShape, dispatchs=[("to", "copy")]): - """ A unique shape which has no well-defined spatial volume or center + """A unique shape which has no well-defined spatial volume or center This special shape is used when composite shapes turns out to have a null space. @@ -419,63 +468,83 @@ class NullShape(PureShape, dispatchs=[("to", "copy")]): Since it has no volume of point in space, none of the arguments has any meaning. """ + __slots__ = () def __init__(self, *args, **kwargs): - """ Initialize a null-shape """ + """Initialize a null-shape""" M = np.finfo(np.float64).max / 100 self._center = np.array([M, M, M], np.float64) def within_index(self, other, *args, **kwargs): - """ Always returns a zero length array """ + """Always returns a zero length array""" return np.empty(0, dtype=np.int32) - @deprecation("toEllipsoid is deprecated, please use shape.to.Ellipsoid(...) instead.", "0.15") + @deprecation( + "toEllipsoid is deprecated, please use shape.to.Ellipsoid(...) instead.", "0.15" + ) def toEllipsoid(self, *args, **kwargs): - """ Return an ellipsoid with radius of size 1e-64 """ + """Return an ellipsoid with radius of size 1e-64""" return self.to.Ellipsoid(*args, **kwargs) - @deprecation("toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15") + @deprecation( + "toSphere is deprecated, please use shape.to.Sphere(...) instead.", "0.15" + ) def toSphere(self, *args, **kwargs): - """ Return a sphere with radius of size 1e-64 """ + """Return a sphere with radius of size 1e-64""" return self.to.Sphere(*args, **kwargs) - @deprecation("toCuboid is deprecated, please use shape.to.Cuboid(...) instead.", "0.15") + @deprecation( + "toCuboid is deprecated, please use shape.to.Cuboid(...) instead.", "0.15" + ) def toCuboid(self, *args, **kwargs): - """ Return a cuboid with side-lengths 1e-64 """ + """Return a cuboid with side-lengths 1e-64""" return self.to.Cuboid(*args, **kwargs) def volume(self, *args, **kwargs): - """ The volume of a null shape is exactly 0. """ - return 0. + """The volume of a null shape is exactly 0.""" + return 0.0 + to_dispatch = NullShape.to + class NullToSphere(ShapeToDispatch): def dispatch(self, *args, center=None, **kwargs): from .ellipsoid import Sphere + shape = self._get_object() if center is None: center = shape.center.copy() - return Sphere(1.e-64, center=center) + return Sphere(1.0e-64, center=center) + + to_dispatch.register("Sphere", NullToSphere) + class NullToEllipsoid(ShapeToDispatch): def dispatch(self, *args, center=None, **kwargs): from .ellipsoid import Ellipsoid + shape = self._get_object() if center is None: center = shape.center.copy() - return Ellipsoid(1.e-64, center=center) + return Ellipsoid(1.0e-64, center=center) + + to_dispatch.register("Ellipsoid", NullToEllipsoid) + class NullToCuboid(ShapeToDispatch): def dispatch(self, *args, center=None, origin=None, **kwargs): from .prism4 import Cuboid + shape = self._get_object() if center is None and origin is None: center = shape.center.copy() - return Cuboid(1.e-64, center=center, origin=origin) + return Cuboid(1.0e-64, center=center, origin=origin) + + to_dispatch.register("Cuboid", NullToCuboid) del to_dispatch diff --git a/src/sisl/shape/ellipsoid.py b/src/sisl/shape/ellipsoid.py index 5fb2cbe0fd..0344200045 100644 --- a/src/sisl/shape/ellipsoid.py +++ b/src/sisl/shape/ellipsoid.py @@ -19,7 +19,7 @@ @set_module("sisl.shape") class Ellipsoid(PureShape): - """ 3D Ellipsoid shape + """3D Ellipsoid shape Parameters ---------- @@ -37,25 +37,31 @@ class Ellipsoid(PureShape): >>> shape.within([0, 2, 0]) True """ - __slots__ = ('_v', '_iv') + + __slots__ = ("_v", "_iv") def __init__(self, v, center=None): super().__init__(center) v = _a.asarrayd(v) if v.size == 1: - self._v = np.identity(3, np.float64) * v # a "Euclidean" sphere + self._v = np.identity(3, np.float64) * v # a "Euclidean" sphere elif v.size == 3: - self._v = np.diag(v.ravel()) # a "Euclidean" ellipsoid + self._v = np.diag(v.ravel()) # a "Euclidean" ellipsoid elif v.size == 9: self._v = v.reshape(3, 3).astype(np.float64) else: - raise ValueError(self.__class__.__name__ + " requires initialization with 3 vectors defining the ellipsoid") + raise ValueError( + self.__class__.__name__ + + " requires initialization with 3 vectors defining the ellipsoid" + ) # If the vectors are not orthogonal, orthogonalize them and issue a warning vv = np.fabs(np.dot(self._v, self._v.T) - np.diag(fnorm2(self._v))) if vv.sum() > 1e-9: - warn(self.__class__.__name__ + ' principal vectors are not orthogonal. ' - 'sisl orthogonalizes the vectors (retaining 1st vector)!') + warn( + self.__class__.__name__ + " principal vectors are not orthogonal. " + "sisl orthogonalizes the vectors (retaining 1st vector)!" + ) self._v[1, :] = orthogonalize(self._v[0, :], self._v[1, :]) self._v[2, :] = orthogonalize(self._v[0, :], self._v[2, :]) @@ -69,15 +75,16 @@ def copy(self): def __str__(self): cr = np.array([self.center, self.radius]) - return self.__class__.__name__ + ('{{c({0:.2f} {1:.2f} {2:.2f}) ' - 'r({3:.2f} {4:.2f} {5:.2f})}}').format(*cr.ravel()) + return self.__class__.__name__ + ( + "{{c({0:.2f} {1:.2f} {2:.2f}) " "r({3:.2f} {4:.2f} {5:.2f})}}" + ).format(*cr.ravel()) def volume(self): - """ Return the volume of the shape """ - return 4. / 3. * pi * product3(self.radius) + """Return the volume of the shape""" + return 4.0 / 3.0 * pi * product3(self.radius) def scale(self, scale): - """ Return a new shape with a larger corresponding to `scale` + """Return a new shape with a larger corresponding to `scale` Parameters ---------- @@ -90,7 +97,7 @@ def scale(self, scale): return self.__class__(self._v * scale, self.center) def expand(self, radius): - """ Expand ellipsoid by a constant value along each radial vector + """Expand ellipsoid by a constant value along each radial vector Parameters ---------- @@ -107,27 +114,35 @@ def expand(self, radius): v1 = expand(self._v[1, :], radius[1]) v2 = expand(self._v[2, :], radius[2]) else: - raise ValueError(self.__class__.__name__ + '.expand requires the radius to be either (1,) or (3,)') + raise ValueError( + self.__class__.__name__ + + ".expand requires the radius to be either (1,) or (3,)" + ) return self.__class__([v0, v1, v2], self.center) def toEllipsoid(self): - """ Return an ellipsoid that encompass this shape (a copy) """ + """Return an ellipsoid that encompass this shape (a copy)""" return self.copy() - @deprecation("toSphere is deprecated, please use shape.to['sphere'](...) instead.", "0.15") + @deprecation( + "toSphere is deprecated, please use shape.to['sphere'](...) instead.", "0.15" + ) def toSphere(self): - """ Return a sphere with a radius equal to the largest radial vector """ + """Return a sphere with a radius equal to the largest radial vector""" r = self.radius.max() return Sphere(r, self.center) - @deprecation("toCuboid is deprecated, please use shape.to['cuboid'](...) instead.", "0.15") + @deprecation( + "toCuboid is deprecated, please use shape.to['cuboid'](...) instead.", "0.15" + ) def toCuboid(self): - """ Return a cuboid with side lengths equal to the diameter of each ellipsoid vectors """ + """Return a cuboid with side lengths equal to the diameter of each ellipsoid vectors""" from .prism4 import Cuboid + return Cuboid(self._v * 2, self.center) - def within_index(self, other, tol=1.e-8): - r""" Return indices of the points that are within the shape + def within_index(self, other, tol=1.0e-8): + r"""Return indices of the points that are within the shape Parameters ---------- @@ -145,16 +160,16 @@ def within_index(self, other, tol=1.e-8): # Get indices where we should do the more # expensive exact check of being inside shape # I.e. this reduces the search space to the box - return indices_in_sphere(tmp, 1. + tol) + return indices_in_sphere(tmp, 1.0 + tol) @property def radius(self): - """ Return the radius of the Ellipsoid """ + """Return the radius of the Ellipsoid""" return fnorm(self._v) @property def radial_vector(self): - """ The radial vectors """ + """The radial vectors""" return self._v @@ -165,6 +180,7 @@ class EllipsoidToEllipsoid(ShapeToDispatch): def dispatch(self, *args, **kwargs): return self._get_object().copy() + to_dispatch.register("Ellipsoid", EllipsoidToEllipsoid) @@ -173,15 +189,18 @@ def dispatch(self, *args, **kwargs): shape = self._get_object() return Sphere(shape.radius.max(), shape.center) + to_dispatch.register("Sphere", EllipsoidToSphere) class EllipsoidToCuboid(ShapeToDispatch): def dispatch(self, *args, **kwargs): from .prism4 import Cuboid + shape = self._get_object() return Cuboid(shape._v * 2, shape.center) + to_dispatch.register("Cuboid", EllipsoidToCuboid) del to_dispatch @@ -189,35 +208,40 @@ def dispatch(self, *args, **kwargs): @set_module("sisl.shape") class Sphere(Ellipsoid, dispatchs=[("to", "keep")]): - """ 3D Sphere + """3D Sphere Parameters ---------- r : float radius of the sphere """ + __slots__ = () def __init__(self, radius, center=None): radius = _a.asarrayd(radius).ravel() if len(radius) > 1: - raise ValueError(self.__class__.__name__ + ' is defined via a single radius. ' - 'An array with more than 1 element is not an allowed argument ' - 'to __init__.') + raise ValueError( + self.__class__.__name__ + " is defined via a single radius. " + "An array with more than 1 element is not an allowed argument " + "to __init__." + ) super().__init__(radius, center=center) def __str__(self): - return '{0}{{c({2:.2f} {3:.2f} {4:.2f}) r({1:.2f})}}'.format(self.__class__.__name__, self.radius, *self.center) + return "{0}{{c({2:.2f} {3:.2f} {4:.2f}) r({1:.2f})}}".format( + self.__class__.__name__, self.radius, *self.center + ) def copy(self): return self.__class__(self.radius, self.center) def volume(self): - """ Return the volume of the sphere """ - return 4. / 3. * pi * self.radius ** 3 + """Return the volume of the sphere""" + return 4.0 / 3.0 * pi * self.radius**3 def scale(self, scale): - """ Return a new sphere with a larger radius + """Return a new sphere with a larger radius Parameters ---------- @@ -227,7 +251,7 @@ def scale(self, scale): return self.__class__(self.radius * scale, self.center) def expand(self, radius): - """ Expand sphere by a constant radius + """Expand sphere by a constant radius Parameters ---------- @@ -238,15 +262,20 @@ def expand(self, radius): @property def radius(self): - """ Return the radius of the Sphere """ + """Return the radius of the Sphere""" return self._v[0, 0] - @deprecation("toSphere is deprecated, please use shape.to['sphere'](...) instead.", "0.15") + @deprecation( + "toSphere is deprecated, please use shape.to['sphere'](...) instead.", "0.15" + ) def toSphere(self): - """ Return a copy of it-self """ + """Return a copy of it-self""" return self.copy() - @deprecation("toEllipsoid is deprecated, please use shape.to['ellipsoid'](...) instead.", "0.15") + @deprecation( + "toEllipsoid is deprecated, please use shape.to['ellipsoid'](...) instead.", + "0.15", + ) def toEllipsoid(self): - """ Convert this sphere into an ellipsoid """ + """Convert this sphere into an ellipsoid""" return Ellipsoid(self.radius, self.center) diff --git a/src/sisl/shape/prism4.py b/src/sisl/shape/prism4.py index 2dc371bb65..c946a456f9 100644 --- a/src/sisl/shape/prism4.py +++ b/src/sisl/shape/prism4.py @@ -13,12 +13,12 @@ from .base import PureShape, ShapeToDispatch -__all__ = ['Cuboid', 'Cube'] +__all__ = ["Cuboid", "Cube"] @set_module("sisl.shape") class Cuboid(PureShape): - """ A cuboid/rectangular prism (P4) + """A cuboid/rectangular prism (P4) Parameters ---------- @@ -40,22 +40,26 @@ class Cuboid(PureShape): >>> shape.within([0, 1.1, 0]) True """ - __slots__ = ('_v', '_iv') - def __init__(self, v, center=None, origin=None): + __slots__ = ("_v", "_iv") + def __init__(self, v, center=None, origin=None): v = _a.asarrayd(v) if v.size == 1: - self._v = np.identity(3) * v # a "Euclidean" cube + self._v = np.identity(3) * v # a "Euclidean" cube elif v.size == 3: - self._v = np.diag(v.ravel()) # a "Euclidean" rectangle + self._v = np.diag(v.ravel()) # a "Euclidean" rectangle elif v.size == 9: self._v = v.reshape(3, 3).astype(np.float64) else: - raise ValueError(f"{self.__class__.__name__} requires initialization with 3 vectors defining the cuboid") + raise ValueError( + f"{self.__class__.__name__} requires initialization with 3 vectors defining the cuboid" + ) if center is not None and origin is not None: - raise ValueError(f"{self.__class__.__name__} only allows either origin or center argument") + raise ValueError( + f"{self.__class__.__name__} only allows either origin or center argument" + ) elif origin is not None: center = self._v.sum(0) / 2 + origin @@ -69,14 +73,16 @@ def copy(self): return self.__class__(self._v, self.center) def __str__(self): - return self.__class__.__name__ + '{{O({1} {2} {3}), vol: {0}}}'.format(self.volume(), *self.origin) + return self.__class__.__name__ + "{{O({1} {2} {3}), vol: {0}}}".format( + self.volume(), *self.origin + ) def volume(self): - """ Return volume of Cuboid """ + """Return volume of Cuboid""" return abs(dot3(self._v[0, :], cross3(self._v[1, :], self._v[2, :]))) def scale(self, scale): - """ Scale the cuboid box size (center is retained) + """Scale the cuboid box size (center is retained) Parameters ---------- @@ -89,7 +95,7 @@ def scale(self, scale): return self.__class__(self._v * scale, self.center) def expand(self, length): - """ Expand the cuboid by a constant value along side vectors + """Expand the cuboid by a constant value along side vectors Parameters ---------- @@ -106,31 +112,41 @@ def expand(self, length): v1 = expand(self._v[1, :], length[1]) v2 = expand(self._v[2, :], length[2]) else: - raise ValueError(self.__class__.__name__ + '.expand requires the length to be either (1,) or (3,)') + raise ValueError( + self.__class__.__name__ + + ".expand requires the length to be either (1,) or (3,)" + ) return self.__class__([v0, v1, v2], self.center) - @deprecation("toEllipsoid is deprecated, please use shape.to['ellipsoid'](...) instead.", "0.15") + @deprecation( + "toEllipsoid is deprecated, please use shape.to['ellipsoid'](...) instead.", + "0.15", + ) def toEllipsoid(self): - """ Return an ellipsoid that encompass this cuboid """ + """Return an ellipsoid that encompass this cuboid""" from .ellipsoid import Ellipsoid # Rescale each vector - return Ellipsoid(self._v / 2 * 3 ** .5, self.center.copy()) + return Ellipsoid(self._v / 2 * 3**0.5, self.center.copy()) - @deprecation("toSphere is deprecated, please use shape.to['sphere'](...) instead.", "0.15") + @deprecation( + "toSphere is deprecated, please use shape.to['sphere'](...) instead.", "0.15" + ) def toSphere(self): - """ Return a sphere that encompass this cuboid """ + """Return a sphere that encompass this cuboid""" from .ellipsoid import Sphere - return Sphere(self.edge_length.max() / 2 * 3 ** .5, self.center.copy()) + return Sphere(self.edge_length.max() / 2 * 3**0.5, self.center.copy()) - @deprecation("toCuboid is deprecated, please use shape.to['cuboid'](...) instead.", "0.15") + @deprecation( + "toCuboid is deprecated, please use shape.to['cuboid'](...) instead.", "0.15" + ) def toCuboid(self): - """ Return a copy of itself """ + """Return a copy of itself""" return self.copy() - def within_index(self, other, tol=1.e-8): - """ Return indices of the `other` object which are contained in the shape + def within_index(self, other, tol=1.0e-8): + """Return indices of the `other` object which are contained in the shape Parameters ---------- @@ -148,49 +164,57 @@ def within_index(self, other, tol=1.e-8): # The proximity is 1e-12 of the inverse cell. # So, sadly, the bigger the cell the bigger the tolerance # However due to numerics this is probably best anyway - return indices_gt_le(tmp, -tol, 1. + tol) + return indices_gt_le(tmp, -tol, 1.0 + tol) @property def origin(self): - """ Return the origin of the Cuboid (lower-left corner) """ + """Return the origin of the Cuboid (lower-left corner)""" return self.center - (self._v * 0.5).sum(0) @origin.setter def origin(self, origin): - """ Re-setting the origin can sometimes be necessary """ + """Re-setting the origin can sometimes be necessary""" self.center = origin + (self._v * 0.5).sum(0) @property def edge_length(self): - """ The lengths of each of the vector that defines the cuboid """ + """The lengths of each of the vector that defines the cuboid""" return fnorm(self._v) to_dispatch = Cuboid.to + class CuboidToEllipsoid(ShapeToDispatch): def dispatch(self, *args, **kwargs): from .ellipsoid import Ellipsoid + shape = self._get_object() # Rescale each vector - return Ellipsoid(shape._v / 2 * 3 ** .5, shape.center.copy()) + return Ellipsoid(shape._v / 2 * 3**0.5, shape.center.copy()) + to_dispatch.register("Ellipsoid", CuboidToEllipsoid) + class CuboidToSphere(ShapeToDispatch): def dispatch(self, *args, **kwargs): from .ellipsoid import Sphere + shape = self._get_object() # Rescale each vector - return Sphere(shape.edge_length.max() / 2 * 3 ** .5, shape.center.copy()) + return Sphere(shape.edge_length.max() / 2 * 3**0.5, shape.center.copy()) + to_dispatch.register("Sphere", CuboidToSphere) + class CuboidToCuboid(ShapeToDispatch): def dispatch(self, *args, **kwargs): shape = self._get_object() return shape.copy() + to_dispatch.register("Cuboid", CuboidToCuboid) del to_dispatch @@ -198,7 +222,7 @@ def dispatch(self, *args, **kwargs): @set_module("sisl.shape") class Cube(Cuboid, dispatchs=[("to", "keep")]): - """ 3D Cube with equal sides + """3D Cube with equal sides Equivalent to ``Cuboid([r, r, r])``. @@ -213,6 +237,7 @@ class Cube(Cuboid, dispatchs=[("to", "keep")]): the lower left corner of the cuboid. Not allowed as argument if `center` is passed. """ + __slots__ = () def __init__(self, side, center=None, origin=None): diff --git a/src/sisl/shape/tests/test_cylinder.py b/src/sisl/shape/tests/test_cylinder.py index a15770ef4d..7072a77db5 100644 --- a/src/sisl/shape/tests/test_cylinder.py +++ b/src/sisl/shape/tests/test_cylinder.py @@ -14,16 +14,16 @@ @pytest.mark.filterwarnings("ignore", message="*orthogonalizes the vectors") def test_create_ellipticalcylinder(): - el = EllipticalCylinder(1., 1.) - el = EllipticalCylinder([1., 1.], 1.) - v0 = [1., 0.2, 1.0] - v1 = [1., -0.2, 1.0] - el = EllipticalCylinder([v0, v1], 1.) + el = EllipticalCylinder(1.0, 1.0) + el = EllipticalCylinder([1.0, 1.0], 1.0) + v0 = [1.0, 0.2, 1.0] + v1 = [1.0, -0.2, 1.0] + el = EllipticalCylinder([v0, v1], 1.0) v0 = el.radial_vector[0] v1 = el.radial_vector[1] v0 /= fnorm(v0) v1 /= fnorm(v1) - el = EllipticalCylinder([v0, v1], 1.) + el = EllipticalCylinder([v0, v1], 1.0) e2 = el.scale(1.1) assert np.allclose(el.radius + 0.1, e2.radius) assert np.allclose(el.height + 0.1, e2.height) @@ -37,7 +37,7 @@ def test_create_ellipticalcylinder(): def test_ellipticalcylinder_within(): - el = EllipticalCylinder(1., 1.) + el = EllipticalCylinder(1.0, 1.0) # center of cylinder assert el.within_index([0, 0, 0])[0] == 0 # should not be in a circle @@ -45,10 +45,10 @@ def test_ellipticalcylinder_within(): def test_tosphere(): - el = EllipticalCylinder([1., 1.], 1.) + el = EllipticalCylinder([1.0, 1.0], 1.0) el.to.Sphere() def test_tocuboid(): - el = EllipticalCylinder([1., 1.], 1.) + el = EllipticalCylinder([1.0, 1.0], 1.0) el.to.Cuboid() diff --git a/src/sisl/shape/tests/test_ellipsoid.py b/src/sisl/shape/tests/test_ellipsoid.py index f4d7081e4b..242e467d5b 100644 --- a/src/sisl/shape/tests/test_ellipsoid.py +++ b/src/sisl/shape/tests/test_ellipsoid.py @@ -15,13 +15,13 @@ @pytest.mark.filterwarnings("ignore", message="*orthogonalizes the vectors") def test_create_ellipsoid(): - el = Ellipsoid(1.) - el = Ellipsoid([1., 1., 1.]) - el = Ellipsoid([1., 1., 1.], [1.] * 3) - el = Ellipsoid([1., 2., 3.]) - v0 = [1., 0.2, 1.0] - v1 = [1., -0.2, 1.0] - v2 = [1., -0.2, -1.0] + el = Ellipsoid(1.0) + el = Ellipsoid([1.0, 1.0, 1.0]) + el = Ellipsoid([1.0, 1.0, 1.0], [1.0] * 3) + el = Ellipsoid([1.0, 2.0, 3.0]) + v0 = [1.0, 0.2, 1.0] + v1 = [1.0, -0.2, 1.0] + v2 = [1.0, -0.2, -1.0] el = Ellipsoid([v0, v1, v2]) q = Quaternion(45, [1, 0, 0]) eye = np.identity(3) @@ -33,45 +33,45 @@ def test_create_ellipsoid(): assert np.allclose(el.radius + 0.1, e2.radius) e2 = el.scale([1.1, 2.1, 3.1]) assert np.allclose(el.radius + [0.1, 1.1, 2.1], e2.radius) - assert el.expand(2).volume() == pytest.approx(4/3 * np.pi * 3 ** 3) + assert el.expand(2).volume() == pytest.approx(4 / 3 * np.pi * 3**3) str(el) def test_tosphere(): - el = Ellipsoid([1., 1., 1.]) + el = Ellipsoid([1.0, 1.0, 1.0]) assert el.to.Sphere().radius == pytest.approx(1) - el = Ellipsoid([1., 2., 3.]) + el = Ellipsoid([1.0, 2.0, 3.0]) assert el.to.Sphere().radius == pytest.approx(3) @pytest.mark.filterwarnings("ignore", message="*orthogonalizes the vectors") def test_create_ellipsoid_fail(): - v0 = [1., 0.2, 1.0] - v1 = [1., 0.2, 1.0] - v2 = [1., -0.2, -1.0] + v0 = [1.0, 0.2, 1.0] + v1 = [1.0, 0.2, 1.0] + v2 = [1.0, -0.2, -1.0] with pytest.raises(ValueError): el = Ellipsoid([v0, v1, v2]) def test_create_ellipsoid_fail2(): - v0 = [1., 0.2, 1.0] - v1 = [1., 0.2, 1.0] - v2 = [1., -0.2, -1.0] - v3 = [1., -0.2, -1.0] + v0 = [1.0, 0.2, 1.0] + v1 = [1.0, 0.2, 1.0] + v2 = [1.0, -0.2, -1.0] + v3 = [1.0, -0.2, -1.0] with pytest.raises(ValueError): el = Ellipsoid([v0, v1, v2, v3]) def test_create_sphere(): - el = Sphere(1.) - el = Sphere(1., center=[1.]*3) - assert el.volume() == pytest.approx(4/3 * np.pi) - assert el.scale(2).volume() == pytest.approx(4/3 * np.pi * 2 ** 3) - assert el.expand(2).volume() == pytest.approx(4/3 * np.pi * 3 ** 3) + el = Sphere(1.0) + el = Sphere(1.0, center=[1.0] * 3) + assert el.volume() == pytest.approx(4 / 3 * np.pi) + assert el.scale(2).volume() == pytest.approx(4 / 3 * np.pi * 2**3) + assert el.expand(2).volume() == pytest.approx(4 / 3 * np.pi * 3**3) def test_scale1(): - e1 = Ellipsoid([1., 1., 1.]) + e1 = Ellipsoid([1.0, 1.0, 1.0]) e2 = e1.scale(1.1) assert np.allclose(e1.radius + 0.1, e2.radius) e2 = e1.scale([1.1] * 3) @@ -81,7 +81,7 @@ def test_scale1(): def test_expand1(): - e1 = Ellipsoid([1., 1., 1.]) + e1 = Ellipsoid([1.0, 1.0, 1.0]) e2 = e1.expand(1.1) assert np.allclose(e1.radius + 1.1, e2.radius) e2 = e1.expand([1.1] * 3) @@ -97,47 +97,47 @@ def test_expand_fail(): def test_within1(): - o = Ellipsoid([1., 2., 3.]) - assert not o.within([-1.]*3).any() - assert o.within([.2]*3) - assert o.within([.5]*3) - o = Ellipsoid([1., 1., 2.]) - assert not o.within([-1.]*3).any() - assert o.within([.2]*3) - assert o.within([.5]*3) - o = Sphere(1.) - assert not o.within([-1.]*3).any() - assert o.within([.2]*3) - assert o.within([.5]*3) + o = Ellipsoid([1.0, 2.0, 3.0]) + assert not o.within([-1.0] * 3).any() + assert o.within([0.2] * 3) + assert o.within([0.5] * 3) + o = Ellipsoid([1.0, 1.0, 2.0]) + assert not o.within([-1.0] * 3).any() + assert o.within([0.2] * 3) + assert o.within([0.5] * 3) + o = Sphere(1.0) + assert not o.within([-1.0] * 3).any() + assert o.within([0.2] * 3) + assert o.within([0.5] * 3) def test_within_index1(): - o = Ellipsoid([1., 2., 3.]) - assert o.within_index([-1.]*3).size == 0 - assert o.within_index([.2]*3) == [0] - assert o.within_index([.5]*3) == [0] - o = Ellipsoid([1., 1., 2.]) - assert o.within_index([-1.]*3).size == 0 - assert o.within_index([.2]*3) == [0] - assert o.within_index([.5]*3) == [0] - o = Sphere(1.) - assert o.within_index([-1.]*3).size == 0 - assert o.within_index([.2]*3) == [0] - assert o.within_index([.5]*3) == [0] + o = Ellipsoid([1.0, 2.0, 3.0]) + assert o.within_index([-1.0] * 3).size == 0 + assert o.within_index([0.2] * 3) == [0] + assert o.within_index([0.5] * 3) == [0] + o = Ellipsoid([1.0, 1.0, 2.0]) + assert o.within_index([-1.0] * 3).size == 0 + assert o.within_index([0.2] * 3) == [0] + assert o.within_index([0.5] * 3) == [0] + o = Sphere(1.0) + assert o.within_index([-1.0] * 3).size == 0 + assert o.within_index([0.2] * 3) == [0] + assert o.within_index([0.5] * 3) == [0] def test_sphere_and(): # For all intersections after - v = np.array([1.] * 3) + v = np.array([1.0] * 3) v /= fnorm(v) D = np.linspace(0, 5, 50) inside = np.ones(len(D), dtype=bool) - A = Sphere(2.) + A = Sphere(2.0) is_first = True inside[:] = True for i, d in enumerate(D): - B = Sphere(1., center=v * d) + B = Sphere(1.0, center=v * d) C = (A & B).to.Sphere() if is_first and C.radius < B.radius: is_first = False @@ -151,7 +151,7 @@ def test_sphere_and(): is_first = True inside[:] = True for i, d in enumerate(D): - B = Sphere(1., center=v * d) + B = Sphere(1.0, center=v * d) C = (A & B).to.Sphere() str(A) + str(B) + str(C) if is_first and C.radius < A.radius: diff --git a/src/sisl/shape/tests/test_prism4.py b/src/sisl/shape/tests/test_prism4.py index 92673f9c9a..b86398db6a 100644 --- a/src/sisl/shape/tests/test_prism4.py +++ b/src/sisl/shape/tests/test_prism4.py @@ -12,22 +12,22 @@ def test_create_cuboid(): - cube = Cuboid([1.0]*3) - cube = Cuboid([1.0]*3, [1.]*3) - cube = Cuboid([1.0, 2.0, 3.0], [1.]*3) - cube = Cuboid([1.0, 2.0, 3.0], origin=[1.]*3) - v0 = [1., 0.2, 1.0] - v1 = [1., -0.2, 1.0] - v2 = [1., -0.2, -1.0] + cube = Cuboid([1.0] * 3) + cube = Cuboid([1.0] * 3, [1.0] * 3) + cube = Cuboid([1.0, 2.0, 3.0], [1.0] * 3) + cube = Cuboid([1.0, 2.0, 3.0], origin=[1.0] * 3) + v0 = [1.0, 0.2, 1.0] + v1 = [1.0, -0.2, 1.0] + v2 = [1.0, -0.2, -1.0] cube = Cuboid([v0, v1, v2]) str(cube) def test_create_fail(): - v0 = [1., 0.2, 1.0] - v1 = [1., 0.2, 1.0] - v2 = [1., -0.2, -1.0] - v3 = [1., -0.2, -1.0] + v0 = [1.0, 0.2, 1.0] + v1 = [1.0, 0.2, 1.0] + v2 = [1.0, -0.2, -1.0] + v3 = [1.0, -0.2, -1.0] with pytest.raises(ValueError): el = Cuboid([v0, v1, v2, v3]) with pytest.raises(ValueError): @@ -35,35 +35,35 @@ def test_create_fail(): def test_tosphere(): - cube = Cube(1.) - assert cube.to.Sphere().radius == pytest.approx(.5 * 3 ** 0.5) - cube = Cube(3.) - assert cube.to.Sphere().radius == pytest.approx(1.5 * 3 ** 0.5) - cube = Cuboid([1., 2., 3.]) - assert cube.to.Sphere().radius == pytest.approx(1.5 * 3 ** 0.5) - assert cube.to.Sphere().radius == pytest.approx(1.5 * 3 ** 0.5) + cube = Cube(1.0) + assert cube.to.Sphere().radius == pytest.approx(0.5 * 3**0.5) + cube = Cube(3.0) + assert cube.to.Sphere().radius == pytest.approx(1.5 * 3**0.5) + cube = Cuboid([1.0, 2.0, 3.0]) + assert cube.to.Sphere().radius == pytest.approx(1.5 * 3**0.5) + assert cube.to.Sphere().radius == pytest.approx(1.5 * 3**0.5) assert isinstance(cube.to.Sphere(), Sphere) def test_toellipsoid(): - cube = Cube(1.) - assert cube.to.Ellipsoid().radius[0] == pytest.approx(.5 * 3 ** 0.5) - cube = Cube(3.) - assert cube.to.Ellipsoid().radius[0] == pytest.approx(1.5 * 3 ** 0.5) - cube = Cuboid([1., 2., 3.]) - assert cube.to.Ellipsoid().radius[0] == pytest.approx(.5 * 3 ** 0.5) - assert cube.to.Ellipsoid().radius[1] == pytest.approx(1 * 3 ** 0.5) - assert cube.to.Ellipsoid().radius[2] == pytest.approx(1.5 * 3 ** 0.5) + cube = Cube(1.0) + assert cube.to.Ellipsoid().radius[0] == pytest.approx(0.5 * 3**0.5) + cube = Cube(3.0) + assert cube.to.Ellipsoid().radius[0] == pytest.approx(1.5 * 3**0.5) + cube = Cuboid([1.0, 2.0, 3.0]) + assert cube.to.Ellipsoid().radius[0] == pytest.approx(0.5 * 3**0.5) + assert cube.to.Ellipsoid().radius[1] == pytest.approx(1 * 3**0.5) + assert cube.to.Ellipsoid().radius[2] == pytest.approx(1.5 * 3**0.5) def test_create_cube(): cube = Cube(1.0) - cube = Cube(1.0, [1.]*3) - assert cube.volume() == pytest.approx(1.) - assert cube.scale(2).volume() == pytest.approx(2 ** 3) - assert cube.scale([2] * 3).volume() == pytest.approx(2 ** 3) - assert cube.expand(2).volume() == pytest.approx(3 ** 3) - assert cube.expand([2] * 3).volume() == pytest.approx(3 ** 3) + cube = Cube(1.0, [1.0] * 3) + assert cube.volume() == pytest.approx(1.0) + assert cube.scale(2).volume() == pytest.approx(2**3) + assert cube.scale([2] * 3).volume() == pytest.approx(2**3) + assert cube.expand(2).volume() == pytest.approx(3**3) + assert cube.expand([2] * 3).volume() == pytest.approx(3**3) def test_expand_fail(): @@ -73,40 +73,38 @@ def test_expand_fail(): def test_vol1(): - cube = Cuboid([1.0]*3) - assert cube.volume() == 1. - cube = Cuboid([1., 2., 3.]) - assert cube.volume() == 6. + cube = Cuboid([1.0] * 3) + assert cube.volume() == 1.0 + cube = Cuboid([1.0, 2.0, 3.0]) + assert cube.volume() == 6.0 return - a = (1./3) ** .5 + a = (1.0 / 3) ** 0.5 v0 = [a, a, 0] v1 = [-a, a, 0] v2 = [0, 0, a] cube = Cuboid([v0, v1, v2]) - assert cube.volume() == 1. + assert cube.volume() == 1.0 def test_origin(): - cube = Cuboid([1.0]*3) + cube = Cuboid([1.0] * 3) assert np.allclose(cube.origin, -0.5) cube.origin = 1 assert np.allclose(cube.origin, 1) def test_within1(): - cube = Cuboid([1.0]*3) - assert not cube.within([-1.]*3).any() - assert not cube.within([[-1.]*3, [-1., 0.5, 0.2]]).any() - assert cube.within([[-1.]*3, - [-1., 0.5, 0.2], - [.1, 0.5, 0.2]]).any() + cube = Cuboid([1.0] * 3) + assert not cube.within([-1.0] * 3).any() + assert not cube.within([[-1.0] * 3, [-1.0, 0.5, 0.2]]).any() + assert cube.within([[-1.0] * 3, [-1.0, 0.5, 0.2], [0.1, 0.5, 0.2]]).any() def test_within_index1(): - cube = Cuboid([1.0]*3) - assert cube.within_index([-1.]*3).size == 0 - assert cube.within_index([[-1.]*3, [-1., 0.5, 0.2]]).size == 0 - assert (cube.within_index([[-1.]*3, - [-1., 0.5, 0.2], - [.1, 0.5, 0.2]]) == [0, 1, 2]).any() + cube = Cuboid([1.0] * 3) + assert cube.within_index([-1.0] * 3).size == 0 + assert cube.within_index([[-1.0] * 3, [-1.0, 0.5, 0.2]]).size == 0 + assert ( + cube.within_index([[-1.0] * 3, [-1.0, 0.5, 0.2], [0.1, 0.5, 0.2]]) == [0, 1, 2] + ).any() diff --git a/src/sisl/shape/tests/test_shape.py b/src/sisl/shape/tests/test_shape.py index 2c88ecba79..95bab97ee9 100644 --- a/src/sisl/shape/tests/test_shape.py +++ b/src/sisl/shape/tests/test_shape.py @@ -12,8 +12,8 @@ def test_binary_op(): - e = Ellipsoid(1.) - s = Sphere(1.) + e = Ellipsoid(1.0) + s = Sphere(1.0) new = e + s str(new) @@ -22,13 +22,13 @@ def test_binary_op(): new = new & (new | e) ^ s new.center - assert new.volume() < 0. + assert new.volume() < 0.0 str(new) def test_null(): null = NullShape() - assert null.volume() == 0. + assert null.volume() == 0.0 assert len(null.within_index(np.random.rand(1000, 3))) == 0 assert null.to.Ellipsoid().volume() < 1e-64 @@ -37,8 +37,8 @@ def test_null(): def test_binary_op_within(): - e = Ellipsoid(.5) - c = Cube(1.) + e = Ellipsoid(0.5) + c = Cube(1.0) xc = [0.499] * 3 new = e + c @@ -46,13 +46,13 @@ def test_binary_op_within(): assert c.within(xc) assert new.within(xc) - xe = [0.] * 3 + xe = [0.0] * 3 new = e - c assert e.within(xe) assert c.within(xe) assert not new.within(xe) - new = (e & c) + new = e & c assert not e.within(xc) assert c.within(xc) assert not new.within(xc) @@ -61,7 +61,7 @@ def test_binary_op_within(): assert new.within(xe) # e ^ c == c ^ e - new = (e ^ c) + new = e ^ c assert not e.within(xc) assert c.within(xc) assert new.within(xc) @@ -69,16 +69,16 @@ def test_binary_op_within(): assert c.within(xe) assert not new.within(xe) - new = (c ^ e) + new = c ^ e assert new.within(xc) assert not new.within(xe) def test_binary_op_toSphere(): - e = Ellipsoid(.5) - c = Cube(1.) + e = Ellipsoid(0.5) + c = Cube(1.0) - r = 0.5 * 3 ** .5 + r = 0.5 * 3**0.5 new = e + c assert new.to.Sphere().radius.max() == pytest.approx(r) @@ -86,47 +86,47 @@ def test_binary_op_toSphere(): assert new.to.Sphere().radius.max() == pytest.approx(r) # with the AND operator we can reduce to smallest enclosed sphere - new = (e & c) + new = e & c assert new.to.Sphere().radius.max() == pytest.approx(0.5) # e ^ c == c ^ e - new = (e ^ c) + new = e ^ c assert new.to.Sphere().radius.max() == pytest.approx(r) - new = (c ^ e) + new = c ^ e assert new.to.Sphere().radius.max() == pytest.approx(r) - new = (c ^ e) + new = c ^ e assert new.scale(2).to.Sphere().radius.max() == pytest.approx(r * 2) def test_toSphere_and(): - left = Sphere(1.) - right = Sphere(1., center=[0.6] * 3) + left = Sphere(1.0) + right = Sphere(1.0, center=[0.6] * 3) new = left & right s = new.to.Sphere() - assert s.radius.max() < .9 + assert s.radius.max() < 0.9 - left = Sphere(2.) - right = Sphere(1., center=[0.5] * 3) + left = Sphere(2.0) + right = Sphere(1.0, center=[0.5] * 3) new = left & right s = new.to.Sphere() - assert s.radius.max() == pytest.approx(1.) + assert s.radius.max() == pytest.approx(1.0) - left = Sphere(2., center=[10, 10, 10]) - right = Sphere(1., center=[10.5] * 3) + left = Sphere(2.0, center=[10, 10, 10]) + right = Sphere(1.0, center=[10.5] * 3) new = left & right s2 = new.to.Sphere() - assert s2.radius.max() == pytest.approx(1.) + assert s2.radius.max() == pytest.approx(1.0) # Assert it also works for displaced centers assert np.allclose(s.radius, s2.radius) assert np.allclose(s.center, s2.center - 10) - left = Sphere(2.) - right = Sphere(1., center=[10.5] * 3) + left = Sphere(2.0) + right = Sphere(1.0, center=[10.5] * 3) new = left & right s = new.to.Sphere() @@ -134,32 +134,32 @@ def test_toSphere_and(): def test_toEllipsoid_and(): - left = Ellipsoid(1.) - right = Ellipsoid(1., center=[0.6] * 3) + left = Ellipsoid(1.0) + right = Ellipsoid(1.0, center=[0.6] * 3) new = left & right s = new.to.Ellipsoid() - assert s.radius.max() < .9 + assert s.radius.max() < 0.9 - left = Ellipsoid(2.) - right = Ellipsoid(1., center=[0.5] * 3) + left = Ellipsoid(2.0) + right = Ellipsoid(1.0, center=[0.5] * 3) new = left & right s = new.to.Ellipsoid() - assert s.radius.max() == pytest.approx(1.) + assert s.radius.max() == pytest.approx(1.0) - left = Ellipsoid(2., center=[10, 10, 10]) - right = Ellipsoid(1., center=[10.5] * 3) + left = Ellipsoid(2.0, center=[10, 10, 10]) + right = Ellipsoid(1.0, center=[10.5] * 3) new = left & right s2 = new.to.Ellipsoid() - assert s2.radius.max() == pytest.approx(1.) + assert s2.radius.max() == pytest.approx(1.0) # Assert it also works for displaced centers assert np.allclose(s.radius, s2.radius) assert np.allclose(s.center, s2.center - 10) - left = Ellipsoid(2.) - right = Ellipsoid(1., center=[10.5] * 3) + left = Ellipsoid(2.0) + right = Ellipsoid(1.0, center=[10.5] * 3) new = left & right s = new.to.Ellipsoid() @@ -167,22 +167,22 @@ def test_toEllipsoid_and(): def test_toCuboid_and(): - left = Cuboid(1.) - right = Cuboid(1., center=[0.6] * 3) + left = Cuboid(1.0) + right = Cuboid(1.0, center=[0.6] * 3) new = left & right s = new.to.Cuboid() - assert s.edge_length.max() < .9 * 2 + assert s.edge_length.max() < 0.9 * 2 - left = Cuboid(2.) - right = Cuboid(1., center=[0.5] * 3) + left = Cuboid(2.0) + right = Cuboid(1.0, center=[0.5] * 3) new = left & right s = new.to.Cuboid() assert s.edge_length.max() >= 1.5 - left = Cuboid(2., center=[10, 10, 10]) - right = Cuboid(1., center=[10.5] * 3) + left = Cuboid(2.0, center=[10, 10, 10]) + right = Cuboid(1.0, center=[10.5] * 3) new = left & right s2 = new.to.Cuboid() @@ -191,8 +191,8 @@ def test_toCuboid_and(): assert np.allclose(s.edge_length, s2.edge_length) assert np.allclose(s.center, s2.center - 10) - left = Cuboid(2.) - right = Cuboid(1., center=[10.5] * 3) + left = Cuboid(2.0) + right = Cuboid(1.0, center=[10.5] * 3) new = left & right s = new.to.Cuboid() diff --git a/src/sisl/sparse.py b/src/sisl/sparse.py index 2392ae92ee..b3be87745b 100644 --- a/src/sisl/sparse.py +++ b/src/sisl/sparse.py @@ -5,6 +5,7 @@ from numbers import Integral import numpy as np + # To speed up the _extend algorithm we limit lookups from numpy import all as np_all from numpy import allclose @@ -52,11 +53,11 @@ # Although this re-implements the CSR in scipy.sparse.csr_matrix # we use it slightly differently and thus require this new sparse pattern. -__all__ = ['SparseCSR', 'ispmatrix', 'ispmatrixd'] +__all__ = ["SparseCSR", "ispmatrix", "ispmatrixd"] def _ncol_to_indptr(ncol): - """ Convert the ncol array into a pointer array """ + """Convert the ncol array into a pointer array""" ptr = _a.emptyi(ncol.size + 1) ptr[0] = 0 _a.cumsumi(ncol, out=ptr[1:]) @@ -116,14 +117,13 @@ class SparseCSR(NDArrayOperatorsMixin): initial total number of non-zero elements This quantity has precedence over `nnzpr` """ + # We don't really need slots, but it is useful # to keep a good overview of which variables are present - __slots__ = ("_shape", "_ns", "_finalized", - "_nnz", "ptr", "ncol", "col", "_D") + __slots__ = ("_shape", "_ns", "_finalized", "_nnz", "ptr", "ncol", "col", "_D") - def __init__(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, - **kwargs): - """ Initialize a new sparse CSR matrix """ + def __init__(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, **kwargs): + """Initialize a new sparse CSR matrix""" # step size in sparse elements # If there isn't enough room for adding @@ -137,22 +137,23 @@ def __init__(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, # input sparse matrix. arg1 = arg1.tocsr() # Default shape to the CSR matrix - kwargs['shape'] = kwargs.get('shape', arg1.shape) - self.__init__((arg1.data, arg1.indices, arg1.indptr), - dim=dim, dtype=dtype, **kwargs) + kwargs["shape"] = kwargs.get("shape", arg1.shape) + self.__init__( + (arg1.data, arg1.indices, arg1.indptr), dim=dim, dtype=dtype, **kwargs + ) elif isinstance(arg1, (tuple, list)): - if isinstance(arg1[0], Integral): - self.__init_shape(arg1, dim=dim, dtype=dtype, - nnzpr=nnzpr, nnz=nnz, - **kwargs) + self.__init_shape( + arg1, dim=dim, dtype=dtype, nnzpr=nnzpr, nnz=nnz, **kwargs + ) elif len(arg1) != 3: - raise ValueError(self.__class__.__name__ + ' sparse array *must* be created ' - 'with data, indices, indptr') + raise ValueError( + self.__class__.__name__ + " sparse array *must* be created " + "with data, indices, indptr" + ) else: - # Correct dimension according to passed array if len(arg1[0].shape) == 2: dim = max(dim, arg1[0].shape[1]) @@ -162,16 +163,15 @@ def __init__(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, dtype = arg1[0].dtype # The first *must* be some sort of array - if 'shape' in kwargs: - shape = kwargs['shape'] + if "shape" in kwargs: + shape = kwargs["shape"] else: - M = len(arg1[2])-1 + M = len(arg1[2]) - 1 N = ((np.amax(arg1[1]) // M) + 1) * M shape = (M, N) - self.__init_shape(shape, dim=dim, dtype=dtype, - nnz=1, **kwargs) + self.__init_shape(shape, dim=dim, dtype=dtype, nnz=1, **kwargs) # Copy data to the arrays self.ptr = arg1[2].astype(int32, copy=False) @@ -193,15 +193,16 @@ def __init__(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, else: self._D[:, 0] = arg1[0] - def __init_shape(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, - **kwargs): - + def __init_shape(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, **kwargs): # The shape of the data... if len(arg1) == 2: # extend to extra dimension arg1 = arg1 + (dim,) elif len(arg1) != 3: - raise ValueError(self.__class__.__name__ + " unrecognized shape input, either a 2-tuple or 3-tuple is required") + raise ValueError( + self.__class__.__name__ + + " unrecognized shape input, either a 2-tuple or 3-tuple is required" + ) # Set default dtype if dtype is None: @@ -210,7 +211,10 @@ def __init_shape(self, arg1, dim=1, dtype=None, nnzpr=20, nnz=None, # unpack size and check the sizes are "physical" M, N, K = arg1 if M <= 0 or N <= 0 or K <= 0: - raise ValueError(self.__class__.__name__ + f" invalid size of sparse matrix, one of the dimensions is zero: M={M}, N={N}, K={K}") + raise ValueError( + self.__class__.__name__ + + f" invalid size of sparse matrix, one of the dimensions is zero: M={M}, N={N}, K={K}" + ) # Store shape self._shape = (M, N, K) @@ -276,7 +280,7 @@ def sparsity_union(cls, *spmats, dtype=None, dim=None, value=0): shape = shape2 + (dim,) elif len(spmats[0].shape) == 3: shape = shape2 + (spmats[0].shape[2],) - else: # csr_matrix + else: # csr_matrix shape = shape2 + (1,) if dtype is None: @@ -290,11 +294,13 @@ def sparsity_union(cls, *spmats, dtype=None, dim=None, value=0): row_cols = [] for mat in spmats: if isinstance(mat, SparseCSR): - row_cols.append(mat.col[mat.ptr[row]:mat.ptr[row] + mat.ncol[row]]) + row_cols.append( + mat.col[mat.ptr[row] : mat.ptr[row] + mat.ncol[row]] + ) else: # we have to ensure it is a csr matrix mat = mat.tocsr() - row_cols.append(mat.indices[mat.indptr[row]:mat.indptr[row+1]]) + row_cols.append(mat.indices[mat.indptr[row] : mat.indptr[row + 1]]) out_col.append(np.unique(concatenate(row_cols))) # Put into the output out.ncol = _a.arrayi([len(cols) for cols in out_col]) @@ -305,7 +311,7 @@ def sparsity_union(cls, *spmats, dtype=None, dim=None, value=0): return out def diagonal(self): - r""" Return the diagonal elements from the matrix """ + r"""Return the diagonal elements from the matrix""" # get the diagonal components diag = np.zeros([self.shape[0], self.shape[2]], dtype=self.dtype) @@ -328,7 +334,7 @@ def diagonal(self): return diag def diags(self, diagonals, offsets=0, dim=None, dtype=None): - """ Create a `SparseCSR` with diagonal elements with the same shape as the routine + """Create a `SparseCSR` with diagonal elements with the same shape as the routine Parameters ---------- @@ -366,7 +372,7 @@ def diags(self, diagonals, offsets=0, dim=None, dtype=None): return D def empty(self, keep_nnz=False): - """ Delete all sparse information from the sparsity pattern + """Delete all sparse information from the sparsity pattern Essentially this deletes all entries. @@ -378,7 +384,7 @@ def empty(self, keep_nnz=False): This may be advantagegous when re-constructing a new sparse matrix from an old sparse matrix """ - self._D[:, :] = 0. + self._D[:, :] = 0.0 if not keep_nnz: self._finalized = False @@ -391,45 +397,45 @@ def empty(self, keep_nnz=False): @property def shape(self): - """ The shape of the sparse matrix """ + """The shape of the sparse matrix""" return self._shape @property def dim(self): - """ The extra dimensionality of the sparse matrix (elements per matrix element) """ + """The extra dimensionality of the sparse matrix (elements per matrix element)""" return self.shape[2] @property def data(self): - """ Data contained in the sparse matrix (numpy array of elements) """ + """Data contained in the sparse matrix (numpy array of elements)""" return self._D @property def dtype(self): - """ The data-type in the sparse matrix """ + """The data-type in the sparse matrix""" return self._D.dtype @property def dkind(self): - """ The data-type in the sparse matrix (in str) """ + """The data-type in the sparse matrix (in str)""" return np.dtype(self._D.dtype).kind @property def nnz(self): - """ Number of non-zero elements in the sparse matrix """ + """Number of non-zero elements in the sparse matrix""" return self._nnz def __len__(self): - """ Number of rows in the sparse matrix """ + """Number of rows in the sparse matrix""" return self.shape[0] @property def finalized(self): - """ Whether the contained data is finalized and non-used elements have been removed """ + """Whether the contained data is finalized and non-used elements have been removed""" return self._finalized def finalize(self, sort=True): - """ Finalizes the sparse matrix by removing all non-set elements + """Finalizes the sparse matrix by removing all non-set elements One may still interact with the sparse matrix as one would previously. @@ -476,30 +482,38 @@ def finalize(self, sort=True): duplicates = np.logical_and(diff(col) == 0, diff(row) == 0).nonzero()[0] if len(duplicates) > 0: - raise SislError('You cannot have two elements between the same ' + - f'i,j index (i={row[duplicates]}, j={col[duplicates]})') + raise SislError( + "You cannot have two elements between the same " + + f"i,j index (i={row[duplicates]}, j={col[duplicates]})" + ) else: for r in range(self.shape[0]): ptr1 = ptr[r] - ptr2 = ptr[r+1] + ptr2 = ptr[r + 1] if unique(col[ptr1:ptr2]).shape[0] != ptr2 - ptr1: - raise SislError('You cannot have two elements between the same ' + - f'i,j index (i={r}), something has went terribly wrong.') + raise SislError( + "You cannot have two elements between the same " + + f"i,j index (i={r}), something has went terribly wrong." + ) if len(col) != self.nnz: - raise SislError('Final size in the sparse matrix finalization went wrong.') # pragma: no cover + raise SislError( + "Final size in the sparse matrix finalization went wrong." + ) # pragma: no cover # Check that all column indices are within the expected shape if np_any(self.shape[1] <= self.col): - warn("Sparse matrix contains column indices outside the shape " - "of the matrix. Data may not represent what is expected!") + warn( + "Sparse matrix contains column indices outside the shape " + "of the matrix. Data may not represent what is expected!" + ) # Signal that we indeed have finalized the data self._finalized = sort def edges(self, row, exclude=None): - """ Retrieve edges (connections) of a given `row` or list of `row`'s + """Retrieve edges (connections) of a given `row` or list of `row`'s The returned edges are unique and sorted (see `numpy.unique`). @@ -529,7 +543,7 @@ def edges(self, row, exclude=None): return edges def delete_columns(self, columns, keep_shape=False): - """ Delete all columns in `columns` (in-place action) + """Delete all columns in `columns` (in-place action) Parameters ---------- @@ -612,7 +626,7 @@ def delete_columns(self, columns, keep_shape=False): self._shape = tuple(shape) def _clean_columns(self): - """ Remove all intrinsic columns that are not defined in the sparse matrix """ + """Remove all intrinsic columns that are not defined in the sparse matrix""" # Grab pointers ptr = self.ptr ncol = self.ncol @@ -647,7 +661,7 @@ def _clean_columns(self): # it will still be def translate_columns(self, old, new, rows=None, clean=True): - """ Takes all `old` columns and translates them to `new`. + """Takes all `old` columns and translates them to `new`. Parameters ---------- @@ -664,15 +678,19 @@ def translate_columns(self, old, new, rows=None, clean=True): new = _a.asarrayi(new) if len(old) != len(new): - raise ValueError(f"{self.__class__.__name__}.translate_columns requires input and output columns with " - "equal length") + raise ValueError( + f"{self.__class__.__name__}.translate_columns requires input and output columns with " + "equal length" + ) if allclose(old, new): # No need to translate anything... return if np_any(old >= self.shape[1]): - raise ValueError(f"{self.__class__.__name__}.translate_columns has non-existing old column values") + raise ValueError( + f"{self.__class__.__name__}.translate_columns has non-existing old column values" + ) # Now do the translation pvt = _a.arangei(self.shape[1]) @@ -693,7 +711,7 @@ def translate_columns(self, old, new, rows=None, clean=True): self._clean_columns() def scale_columns(self, cols, scale, rows=None): - r""" Scale all values with certain column values with a number + r"""Scale all values with certain column values with a number This will multiply all values with certain column values with `scale` @@ -715,7 +733,9 @@ def scale_columns(self, cols, scale, rows=None): cols = _a.asarrayi(cols) if np_any(cols >= self.shape[1]): - raise ValueError(f"{self.__class__.__name__}.scale_columns has non-existing old column values") + raise ValueError( + f"{self.__class__.__name__}.scale_columns has non-existing old column values" + ) # Find indices if rows is None: @@ -728,11 +748,11 @@ def scale_columns(self, cols, scale, rows=None): self._D[idx[scale_idx]] *= scale def todense(self): - """ Return a dense `numpy.ndarray` which has 3 dimensions (self.shape) """ + """Return a dense `numpy.ndarray` which has 3 dimensions (self.shape)""" return sparse_dense(self) def spsame(self, other): - """ Check whether two sparse matrices have the same non-zero elements + """Check whether two sparse matrices have the same non-zero elements Parameters ---------- @@ -758,13 +778,20 @@ def spsame(self, other): return False for r in range(self.shape[0]): - if len(intersect1d(scol[sptr[r]:sptr[r]+sncol[r]], - ocol[optr[r]:optr[r]+oncol[r]])) != sncol[r]: + if ( + len( + intersect1d( + scol[sptr[r] : sptr[r] + sncol[r]], + ocol[optr[r] : optr[r] + oncol[r]], + ) + ) + != sncol[r] + ): return False return True def align(self, other): - """ Aligns this sparse matrix with the sparse elements of the other sparse matrix + """Aligns this sparse matrix with the sparse elements of the other sparse matrix Routine for ensuring that all non-zero elements in `other` are also in this object. @@ -780,7 +807,7 @@ def align(self, other): """ if self.shape[:2] != other.shape[:2]: - raise ValueError('Aligning two sparse matrices requires same shapes') + raise ValueError("Aligning two sparse matrices requires same shapes") sptr = self.ptr sncol = self.ncol @@ -795,17 +822,17 @@ def align(self, other): on = oncol[r] if sn == 0: - self._extend(r, ocol[op:op+on], False) + self._extend(r, ocol[op : op + on], False) continue sp = sptr[r] - adds = setdiff1d(ocol[op:op+on], scol[sp:sp+sn]) + adds = setdiff1d(ocol[op : op + on], scol[sp : sp + sn]) if len(adds) > 0: # simply extend the elements self._extend(r, adds, False) def iter_nnz(self, row=None): - """ Iterations of the non-zero elements, returns a tuple of row and column with non-zero elements + """Iterations of the non-zero elements, returns a tuple of row and column with non-zero elements An iterator returning the current row index and the corresponding column index. @@ -827,20 +854,20 @@ def iter_nnz(self, row=None): for r in range(self.shape[0]): n = self.ncol[r] ptr = self.ptr[r] - for c in self.col[ptr:ptr+n]: + for c in self.col[ptr : ptr + n]: yield r, c else: for r in _a.asarrayi(row).ravel(): n = self.ncol[r] ptr = self.ptr[r] - for c in self.col[ptr:ptr+n]: + for c in self.col[ptr : ptr + n]: yield r, c # Define default iterator __iter__ = iter_nnz def _slice2list(self, slc, axis): - """ Convert a slice to a list depending on the provided details """ + """Convert a slice to a list depending on the provided details""" if not isinstance(slc, slice): return slc @@ -851,7 +878,7 @@ def _slice2list(self, slc, axis): return range(idx[0], idx[1], idx[2]) def _extend(self, i, j, ret_indices=True): - """ Extends the sparsity pattern to retain elements `j` in row `i` + """Extends the sparsity pattern to retain elements `j` in row `i` Parameters ---------- @@ -878,7 +905,7 @@ def _extend(self, i, j, ret_indices=True): raise IndexError(f"row index is out-of-bounds {i} : {self.shape[0]}") i1 = int(i) + 1 # We skip this check and let sisl die if wrong input is given... - #if not isinstance(i, Integral): + # if not isinstance(i, Integral): # raise ValueError("Retrieving/Setting elements in a sparse matrix" # " must only be performed at one row-element at a time.\n" # "However, multiple columns at a time are allowed.") @@ -901,12 +928,12 @@ def _extend(self, i, j, ret_indices=True): # we first find which values are _not_ in the sparse # matrix if ncol_i > 0: - # Checks whether any non-zero elements are # already in the sparse pattern # If so we remove those from the j - new_j = j[in1d(j, col[ptr_i:ptr_i+ncol_i], - invert=True, assume_unique=True)] + new_j = j[ + in1d(j, col[ptr_i : ptr_i + ncol_i], invert=True, assume_unique=True) + ] else: new_j = j @@ -922,7 +949,7 @@ def _extend(self, i, j, ret_indices=True): new_nnz = new_n - int(ptr[i1]) + ncol_ptr_i if new_nnz > 0: - #print(f"new_nnz {i} : {new_nnz}") + # print(f"new_nnz {i} : {new_nnz}") # Ensure that it is not-set as finalized # There is no need to set it all the time. # Simply because the first call to finalize @@ -945,8 +972,9 @@ def _extend(self, i, j, ret_indices=True): # Insert zero data in the data array # We use `zeros` as then one may set each dimension # individually... - self._D = insert(self._D, ptr[i1], - zeros([ns, self.shape[2]], self._D.dtype), axis=0) + self._D = insert( + self._D, ptr[i1], zeros([ns, self.shape[2]], self._D.dtype), axis=0 + ) # Lastly, shift all pointers above this row to account for the # new non-zero elements @@ -958,7 +986,7 @@ def _extend(self, i, j, ret_indices=True): # assign the column indices for the new entries # NOTE that this may not assign them in the order # of entry as new_j is sorted and thus new_j != j - col[ncol_ptr_i:ncol_ptr_i+new_n] = new_j + col[ncol_ptr_i : ncol_ptr_i + new_n] = new_j # Step the size of the stored non-zero elements self.ncol[i] += int32(new_n) @@ -976,7 +1004,7 @@ def _extend(self, i, j, ret_indices=True): return indices(col[ptr_i:ncol_ptr_i], j, ptr_i) def _extend_empty(self, i, n): - """ Extends the sparsity pattern with `n` elements in row `i` + """Extends the sparsity pattern with `n` elements in row `i` Parameters ---------- @@ -991,7 +1019,7 @@ def _extend_empty(self, i, n): for indices out of bounds """ if i < 0 or i >= self.shape[0]: - raise IndexError('row index is out-of-bounds') + raise IndexError("row index is out-of-bounds") # fast reference i1 = int(i) + 1 @@ -1005,20 +1033,23 @@ def _extend_empty(self, i, n): # Insert new empty elements in the column index # after the column - self.col = insert(self.col, self.ptr[i] + self.ncol[i], full(n, -1, self.col.dtype)) + self.col = insert( + self.col, self.ptr[i] + self.ncol[i], full(n, -1, self.col.dtype) + ) # Insert zero data in the data array # We use `zeros` as then one may set each dimension # individually... - self._D = insert(self._D, self.ptr[i1], - zeros([n, self.shape[2]], self._D.dtype), axis=0) + self._D = insert( + self._D, self.ptr[i1], zeros([n, self.shape[2]], self._D.dtype), axis=0 + ) # Lastly, shift all pointers above this row to account for the # new non-zero elements self.ptr[i1:] += int32(n) def _get(self, i, j): - """ Retrieves the data pointer arrays of the elements, if it is non-existing, it will return ``-1`` + """Retrieves the data pointer arrays of the elements, if it is non-existing, it will return ``-1`` Parameters ---------- @@ -1039,11 +1070,11 @@ def _get(self, i, j): ptr = self.ptr[i] if j.ndim == 0: - return indices(self.col[ptr:ptr+self.ncol[i]], j.ravel(), ptr)[0] - return indices(self.col[ptr:ptr+self.ncol[i]], j, ptr) + return indices(self.col[ptr : ptr + self.ncol[i]], j.ravel(), ptr)[0] + return indices(self.col[ptr : ptr + self.ncol[i]], j, ptr) def _get_only(self, i, j): - """ Retrieves the data pointer arrays of the elements, only return elements in the sparse array + """Retrieves the data pointer arrays of the elements, only return elements in the sparse array Parameters ---------- @@ -1063,10 +1094,10 @@ def _get_only(self, i, j): # Make it a little easier ptr = self.ptr[i] - return indices_only(self.col[ptr:ptr+self.ncol[i]], j) + ptr + return indices_only(self.col[ptr : ptr + self.ncol[i]], j) + ptr def __delitem__(self, key): - """ Remove items from the sparse patterns """ + """Remove items from the sparse patterns""" # Get indices of sparse data (-1 if non-existing) key = list(key) key[0] = self._slice2list(key[0], 0) @@ -1115,7 +1146,7 @@ def __delitem__(self, key): self._nnz -= n_index def __getitem__(self, key): - """ Intrinsic sparse matrix retrieval of a non-zero element """ + """Intrinsic sparse matrix retrieval of a non-zero element""" # Get indices of sparse data (-1 if non-existing) get_idx = self._get(key[0], key[1]) @@ -1132,14 +1163,12 @@ def __getitem__(self, key): # Check which data to retrieve if len(key) > 2: - # user requests a specific element # get dimension retrieved r = zeros(n, dtype=self._D.dtype) r[ret_idx] = self._D[get_idx, key[2]] else: - # user request all stored data s = self.shape[2] @@ -1157,7 +1186,7 @@ def __getitem__(self, key): return r def __setitem__(self, key, data): - """ Intrinsic sparse matrix assignment of the item. + """Intrinsic sparse matrix assignment of the item. It will only allow to set the data in the sparse matrix if the dimensions match. @@ -1182,7 +1211,6 @@ def __setitem__(self, key, data): i = key[0] j = key[1] if isinstance(i, (list, ndarray)) and isinstance(j, (list, ndarray)): - # Create a b-cast object to iterate # Note that this does not do the actual b-casting and thus # we can iterate and operate as though it was an actual array @@ -1208,10 +1236,18 @@ def __setitem__(self, key, data): # *only* if # do a sanity check if self.dim > 1 and len(key) == 2: - raise ValueError("could not broadcast input array from shape {} into shape {}".format(data.shape, ij.shape + (self.dim,))) + raise ValueError( + "could not broadcast input array from shape {} into shape {}".format( + data.shape, ij.shape + (self.dim,) + ) + ) if len(key) == 3: if atleast_1d(key[2]).size > 1: - raise ValueError("could not broadcast input array from shape {} into shape {}".format(data.shape, ij.shape + (atleast_1d(key[2]).size,))) + raise ValueError( + "could not broadcast input array from shape {} into shape {}".format( + data.shape, ij.shape + (atleast_1d(key[2]).size,) + ) + ) # flatten data data.shape = (-1,) # ij.ndim == 1 @@ -1220,7 +1256,9 @@ def __setitem__(self, key, data): # if ij.size != data.shape[0] an error should occur down below elif data.ndim == 3: if ij.ndim != 2: - raise ValueError("could not broadcast input array from 3 dimensions into 2") + raise ValueError( + "could not broadcast input array from 3 dimensions into 2" + ) data.shape = (-1, data.shape[2]) # Now we need to figure out the final dimension and how to @@ -1272,12 +1310,12 @@ def __setitem__(self, key, data): self._D[index, :] = data[:, :] def __contains__(self, key): - """ Check whether a sparse index is non-zero """ + """Check whether a sparse index is non-zero""" # Get indices of sparse data (-1 if non-existing) return np_all(self._get(key[0], key[1]) >= 0) def nonzero(self, rows=None, only_cols=False): - """ Row and column indices where non-zero elements exists + """Row and column indices where non-zero elements exists Parameters ---------- @@ -1308,8 +1346,8 @@ def nonzero(self, rows=None, only_cols=False): return cols return rows, cols - def eliminate_zeros(self, atol=0.): - """ Remove all zero elememts from the sparse matrix + def eliminate_zeros(self, atol=0.0): + """Remove all zero elememts from the sparse matrix This is an *in-place* operation @@ -1336,7 +1374,6 @@ def eliminate_zeros(self, atol=0.): return for r in range(self.shape[0]): - # Create short-hand slice idx = arangei(ptr[r], ptr[r] + ncol[r]) @@ -1349,7 +1386,7 @@ def eliminate_zeros(self, atol=0.): del self[r, col[idx[C0]]] def copy(self, dims=None, dtype=None): - """ A deepcopy of the sparse matrix + """A deepcopy of the sparse matrix Parameters ---------- @@ -1379,8 +1416,8 @@ def copy(self, dims=None, dtype=None): # The default sizes are not passed # Hence we *must* copy the arrays # directly - copyto(new.ptr, self.ptr, casting='same_kind') - copyto(new.ncol, self.ncol, casting='same_kind') + copyto(new.ptr, self.ptr, casting="same_kind") + copyto(new.ncol, self.ncol, casting="same_kind") new.col = self.col.copy() new._nnz = self.nnz @@ -1394,7 +1431,7 @@ def copy(self, dims=None, dtype=None): return new def tocsr(self, dim=0, **kwargs): - """ Convert dimension `dim` into a :class:`~scipy.sparse.csr_matrix` format + """Convert dimension `dim` into a :class:`~scipy.sparse.csr_matrix` format Parameters ---------- @@ -1406,20 +1443,29 @@ def tocsr(self, dim=0, **kwargs): shape = self.shape[:2] if self.finalized: # Easy case... - return csr_matrix((self._D[:, dim].copy(), - self.col.astype(int32, copy=True), self.ptr.astype(int32, copy=True)), - shape=shape, **kwargs) + return csr_matrix( + ( + self._D[:, dim].copy(), + self.col.astype(int32, copy=True), + self.ptr.astype(int32, copy=True), + ), + shape=shape, + **kwargs, + ) # Use array_arange idx = array_arange(self.ptr[:-1], n=self.ncol) # create new pointer ptr = _ncol_to_indptr(self.ncol) - return csr_matrix((self._D[idx, dim].copy(), self.col[idx], ptr.astype(int32, copy=False)), - shape=shape, **kwargs) + return csr_matrix( + (self._D[idx, dim].copy(), self.col[idx], ptr.astype(int32, copy=False)), + shape=shape, + **kwargs, + ) def transform(self, matrix, dtype=None): - r""" Apply a linear transformation :math:`R^n \rightarrow R^m` to the :math:`n`-dimensional elements of the sparse matrix + r"""Apply a linear transformation :math:`R^n \rightarrow R^m` to the :math:`n`-dimensional elements of the sparse matrix Notes ----- @@ -1440,9 +1486,11 @@ def transform(self, matrix, dtype=None): dtype = np.find_common_type([self.dtype, matrix.dtype], []) if matrix.shape[1] != self.shape[2]: - raise ValueError(f"{self.__class__.__name__}.transform incompatible " - f"transformation matrix and spin dimensions: " - f"matrix.shape={matrix.shape} and self.spin={self.shape[2]} ; out.spin={matrix.shape[0]}") + raise ValueError( + f"{self.__class__.__name__}.transform incompatible " + f"transformation matrix and spin dimensions: " + f"matrix.shape={matrix.shape} and self.spin={self.shape[2]} ; out.spin={matrix.shape[0]}" + ) # set dimension of new sparse matrix new_dim = matrix.shape[0] @@ -1451,8 +1499,8 @@ def transform(self, matrix, dtype=None): new = self.__class__(shape, dtype=dtype, nnz=1) - copyto(new.ptr, self.ptr, casting='no') - copyto(new.ncol, self.ncol, casting='no') + copyto(new.ptr, self.ptr, casting="no") + copyto(new.ncol, self.ncol, casting="no") new.col = self.col.copy() new._nnz = self.nnz @@ -1464,7 +1512,7 @@ def transform(self, matrix, dtype=None): @classmethod def fromsp(cls, *sps, dtype=None): - """ Combine multiple single-dimension sparse matrices into one SparseCSR matrix + """Combine multiple single-dimension sparse matrices into one SparseCSR matrix The different sparse matrices need not have the same sparsity pattern. @@ -1497,15 +1545,17 @@ def fromsp(cls, *sps, dtype=None): m = m.tocsr() m.sort_indices() for r in range(out.shape[0]): - msl = slice(m.indptr[r], m.indptr[r+1]) + msl = slice(m.indptr[r], m.indptr[r + 1]) osl = slice(out.ptr[r], out.ptr[r] + out.ncol[r]) - oidx = indices(out.col[osl], m.indices[msl], osl.start, both_sorted=True) + oidx = indices( + out.col[osl], m.indices[msl], osl.start, both_sorted=True + ) out._D[oidx, im] = m.data[msl] return out def remove(self, indices): - """ Return a new sparse CSR matrix with all the indices removed + """Return a new sparse CSR matrix with all the indices removed Parameters ---------- @@ -1524,7 +1574,7 @@ def remove(self, indices): return self.sub(rindices) def sub(self, indices): - """ Create a new sparse CSR matrix with the data only for the given rows and columns + """Create a new sparse CSR matrix with the data only for the given rows and columns All rows and columns in `indices` are retained, everything else is removed. @@ -1591,8 +1641,7 @@ def sub(self, indices): # Count number of entries idx_take = col_data[1, :] >= 0 - ncol1[:] = _a.fromiteri(map(count_nonzero, - split(idx_take, ptr1[1:-1]))).ravel() + ncol1[:] = _a.fromiteri(map(count_nonzero, split(idx_take, ptr1[1:-1]))).ravel() # Convert to indices idx_take = idx_take.nonzero()[0] @@ -1612,7 +1661,7 @@ def sub(self, indices): return csr def transpose(self, sort=True): - """ Create the transposed sparse matrix + """Create the transposed sparse matrix Parameters ---------- @@ -1642,12 +1691,12 @@ def transpose(self, sort=True): # First extract the actual data ncol = self.ncol.view() if self.finalized: - #ptr = self.ptr.view() + # ptr = self.ptr.view() col = self.col.copy() D = self._D.copy() else: idx = array_arange(self.ptr[:-1], n=ncol, dtype=int32) - #ptr = _ncol_to_indptr(ncol) + # ptr = _ncol_to_indptr(ncol) col = self.col[idx] D = self._D[idx, :].copy() del idx @@ -1689,9 +1738,14 @@ def transpose(self, sort=True): return T def __str__(self): - """ Representation of the sparse matrix model """ + """Representation of the sparse matrix model""" ints = self.shape[:] + (self.nnz,) - return self.__class__.__name__ + '{{dim={2}, kind={kind},\n rows: {0}, columns: {1},\n non-zero: {3}\n}}'.format(*ints, kind=self.dkind) + return ( + self.__class__.__name__ + + "{{dim={2}, kind={kind},\n rows: {0}, columns: {1},\n non-zero: {3}\n}}".format( + *ints, kind=self.dkind + ) + ) def __repr__(self): return f"<{self.__module__}.{self.__class__.__name__} shape={self.shape}, kind={self.dkind}, nnz={self.nnz}>" @@ -1700,7 +1754,7 @@ def __repr__(self): __array_priority__ = 14 def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - #print(f"{self.__class__.__name__}.__array_ufunc__ :", ufunc, method) + # print(f"{self.__class__.__name__}.__array_ufunc__ :", ufunc, method) out = kwargs.pop("out", None) if getattr(ufunc, "signature", None) is not None: @@ -1716,13 +1770,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): elif method == "reduce": result = _ufunc_reduce(ufunc, *inputs, **kwargs) elif method == "outer": - #print("running outer") + # print("running outer") # Currently I don't know what to do here # We don't have multidimensional sparse matrices, # but perhaps that could be needed later? return NotImplemented else: - #print("running method = ", method) + # print("running method = ", method) return NotImplemented if out is None: @@ -1731,8 +1785,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): out[...] = result[...] elif isinstance(out, SparseCSR): if out.shape != result.shape: - raise ValueError(f"non-broadcastable output operand with shape {out.shape} " - "doesn't match the broadcast shape {result.shape}") + raise ValueError( + f"non-broadcastable output operand with shape {out.shape} " + "doesn't match the broadcast shape {result.shape}" + ) out._finalized = result._finalized out.ncol[:] = result.ncol[:] out.ptr[:] = result.ptr[:] @@ -1744,30 +1800,30 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return out def __getstate__(self): - """ Return dictionary with the current state (finalizing the object may reduce memory footprint) """ + """Return dictionary with the current state (finalizing the object may reduce memory footprint)""" d = { - 'shape': self._shape[:], - 'ncol': self.ncol.copy(), - 'col': self.col.copy(), - 'D': self._D.copy(), - 'finalized': self._finalized + "shape": self._shape[:], + "ncol": self.ncol.copy(), + "col": self.col.copy(), + "D": self._D.copy(), + "finalized": self._finalized, } if not self.finalized: - d['ptr'] = self.ptr.copy() + d["ptr"] = self.ptr.copy() return d def __setstate__(self, state): - """ Reset state of the object """ - self._shape = tuple(state['shape'][:]) - self.ncol = state['ncol'] - self.col = state['col'] - self._D = state['D'] + """Reset state of the object""" + self._shape = tuple(state["shape"][:]) + self.ncol = state["ncol"] + self.col = state["col"] + self._D = state["D"] self._nnz = self.ncol.sum() - self._finalized = state['finalized'] + self._finalized = state["finalized"] if self.finalized: self.ptr = _ncol_to_indptr(self.ncol) else: - self.ptr = state['ptr'] + self.ptr = state["ptr"] def _get_reduced_shape(shape): @@ -1824,7 +1880,7 @@ def _ufunc_ndarray_sp(ufunc, a, b, **kwargs): def _ufunc_sp_sp(ufunc, a, b, **kwargs): - """ Calculate ufunc on sparse matrices """ + """Calculate ufunc on sparse matrices""" if isinstance(a, tuple): a = SparseCSR.fromsp(*a) if isinstance(b, tuple): @@ -1832,8 +1888,10 @@ def _ufunc_sp_sp(ufunc, a, b, **kwargs): def accessors(mat): if isinstance(mat, SparseCSR): + def rowslice(r): return slice(mat.ptr[r], mat.ptr[r] + mat.ncol[r]) + accessors = mat.dim, mat.col, mat._D, rowslice issorted = mat.finalized else: @@ -1841,14 +1899,20 @@ def rowslice(r): # and csr_matrix.tocsr is a no-op mat = mat.tocsr() mat.sort_indices() + def rowslice(r): return slice(mat.indptr[r], mat.indptr[r + 1]) + accessors = 1, mat.indices, mat.data.reshape(-1, 1), rowslice issorted = mat.has_sorted_indices if issorted: - indexfunc = lambda ocol, matcol, offset: indices(ocol, matcol, offset, both_sorted=True) + indexfunc = lambda ocol, matcol, offset: indices( + ocol, matcol, offset, both_sorted=True + ) else: - indexfunc = lambda ocol, matcol, offset: np.searchsorted(ocol, matcol) + offset + indexfunc = ( + lambda ocol, matcol, offset: np.searchsorted(ocol, matcol) + offset + ) return accessors + (indexfunc,) adim, acol, adata, arow, afindidx = accessors(a) @@ -1862,7 +1926,7 @@ def rowslice(r): for r in range(out.shape[0]): offset = out.ptr[r] - ocol = out.col[offset:offset + out.ncol[r]] + ocol = out.col[offset : offset + out.ncol[r]] asl = arow(r) aidx = afindidx(ocol, acol[asl], offset) @@ -1880,8 +1944,9 @@ def rowslice(r): # overlapping indices if iover.size > 0: - out._D[iover, :] = ufunc(adata[asl[aover], :], - bdata[bsl[bover], :], **kwargs) + out._D[iover, :] = ufunc( + adata[asl[aover], :], bdata[bsl[bover], :], **kwargs + ) if iaonly.size > 0: # only a @@ -1894,7 +1959,6 @@ def rowslice(r): def _ufunc_call(ufunc, *in_args, **kwargs): - # first process in_args to args # by numpy-fying and checking for sparsecsr args = [] @@ -1922,7 +1986,8 @@ def spshape(arg): # but SparseCSR always have 3, so we pad with ones. return arg.shape + (1,) return arg.shape - #shape = _get_bcast_shape(*tuple(spshape(arg) for arg in args)) + + # shape = _get_bcast_shape(*tuple(spshape(arg) for arg in args)) if len(args) == 1: a = args[0] @@ -1939,27 +2004,31 @@ def spshape(arg): def _(a, b): return _ufunc(ufunc, a, b, **kwargs) + return reduce(_, args) def _ufunc_reduce(ufunc, array, axis=0, *args, **kwargs): - if "dtype" not in kwargs: kwargs["dtype"] = np.result_type(array, *args) # currently the initial argument does not work properly if the # size isn't correct - if np.asarray(kwargs.get("initial", 0.)).ndim > 1: - raise ValueError(f"{array.__class__.__name__}.{ufunc.__name__}.reduce currently does not implement initial values in different dimensions") + if np.asarray(kwargs.get("initial", 0.0)).ndim > 1: + raise ValueError( + f"{array.__class__.__name__}.{ufunc.__name__}.reduce currently does not implement initial values in different dimensions" + ) if isinstance(axis, (tuple, list, np.ndarray)): if len(axis) == 1: axis = axis[0] else: + def wrap_axis(axis): if axis < 0: return axis + len(array.shape) return axis + axis = tuple(wrap_axis(ax) for ax in axis) if axis == (0, 1) or axis == (1, 0): return ufunc.reduce(array._D, axis=0, *args, **kwargs) @@ -1983,20 +2052,26 @@ def wrap_axis(axis): out._D[:, 0] = ufunc.reduce(array._D, axis=1, *args, **kwargs) return out else: - raise ValueError(f"Unknown axis argument in ufunc.reduce call on {array.__class__.__name__}") + raise ValueError( + f"Unknown axis argument in ufunc.reduce call on {array.__class__.__name__}" + ) - ret = empty([array.shape[0], array.shape[2]], dtype=kwargs.get("dtype", array.dtype)) + ret = empty( + [array.shape[0], array.shape[2]], dtype=kwargs.get("dtype", array.dtype) + ) # Now do ufunc calculations, note that initial gets passed directly ptr = array.ptr ncol = array.ncol for r in range(array.shape[0]): - ret[r, :] = ufunc.reduce(array._D[ptr[r]:ptr[r]+ncol[r], :], axis=0, *args, **kwargs) + ret[r, :] = ufunc.reduce( + array._D[ptr[r] : ptr[r] + ncol[r], :], axis=0, *args, **kwargs + ) return ret @set_module("sisl") def ispmatrix(matrix, map_row=None, map_col=None): - """ Iterator for iterating rows and columns for non-zero elements in a `scipy.sparse.*_matrix` (or `SparseCSR`) + """Iterator for iterating rows and columns for non-zero elements in a `scipy.sparse.*_matrix` (or `SparseCSR`) If either `map_row` or `map_col` are not None the generator will only yield the unique values. @@ -2039,61 +2114,73 @@ def ispmatrix(matrix, map_row=None, map_col=None): rows[:] = False # Consider using the numpy nditer function for buffered iterations - #it = np.nditer([geom.o2a(tmp.row), geom.o2a(tmp.col % geom.no), tmp.data], + # it = np.nditer([geom.o2a(tmp.row), geom.o2a(tmp.col % geom.no), tmp.data], # flags=['buffered'], op_flags=['readonly']) if issparse(matrix) and matrix.format == "csr": for r in range(matrix.shape[0]): rr = map_row(r) - if rows[rr]: continue + if rows[rr]: + continue rows[rr] = True cols[:] = False - for ind in range(matrix.indptr[r], matrix.indptr[r+1]): + for ind in range(matrix.indptr[r], matrix.indptr[r + 1]): c = map_col(matrix.indices[ind]) - if cols[c]: continue + if cols[c]: + continue cols[c] = True yield rr, c elif issparse(matrix) and matrix.format == "lil": for r in range(matrix.shape[0]): rr = map_row(r) - if rows[rr]: continue + if rows[rr]: + continue rows[rr] = True cols[:] = False if len(matrix.rows[r]) == 0: continue for c in map_col(matrix.rows[r]): - if cols[c]: continue + if cols[c]: + continue cols[c] = True yield rr, c elif issparse(matrix) and matrix.format == "coo": - raise ValueError("mapping and unique returns are not implemented for COO matrix") + raise ValueError( + "mapping and unique returns are not implemented for COO matrix" + ) elif issparse(matrix) and matrix.format == "csc": - raise ValueError("mapping and unique returns are not implemented for CSC matrix") + raise ValueError( + "mapping and unique returns are not implemented for CSC matrix" + ) elif isinstance(matrix, SparseCSR): for r in range(matrix.shape[0]): rr = map_row(r) - if rows[rr]: continue + if rows[rr]: + continue rows[rr] = True cols[:] = False n = matrix.ncol[r] if n == 0: continue ptr = matrix.ptr[r] - for c in map_col(matrix.col[ptr:ptr+n]): - if cols[c]: continue + for c in map_col(matrix.col[ptr : ptr + n]): + if cols[c]: + continue cols[c] = True yield rr, c else: - raise NotImplementedError("The iterator for this sparse matrix has not been implemented") + raise NotImplementedError( + "The iterator for this sparse matrix has not been implemented" + ) def _ispmatrix_all(matrix): - """ Iterator for iterating rows and columns for non-zero elements in a ``scipy.sparse.*_matrix`` (or `SparseCSR`) + """Iterator for iterating rows and columns for non-zero elements in a ``scipy.sparse.*_matrix`` (or `SparseCSR`) Parameters ---------- @@ -2107,7 +2194,7 @@ def _ispmatrix_all(matrix): """ if issparse(matrix) and matrix.format == "csr": for r in range(matrix.shape[0]): - for ind in range(matrix.indptr[r], matrix.indptr[r+1]): + for ind in range(matrix.indptr[r], matrix.indptr[r + 1]): yield r, matrix.indices[ind] elif issparse(matrix) and matrix.format == "lil": @@ -2120,23 +2207,25 @@ def _ispmatrix_all(matrix): elif issparse(matrix) and matrix.format == "csc": for c in range(matrix.shape[1]): - for ind in range(matrix.indptr[c], matrix.indptr[c+1]): + for ind in range(matrix.indptr[c], matrix.indptr[c + 1]): yield matrix.indices[ind], c elif isinstance(matrix, SparseCSR): for r in range(matrix.shape[0]): n = matrix.ncol[r] ptr = matrix.ptr[r] - for c in matrix.col[ptr:ptr+n]: + for c in matrix.col[ptr : ptr + n]: yield r, c else: - raise NotImplementedError("The iterator for this sparse matrix has not been implemented") + raise NotImplementedError( + "The iterator for this sparse matrix has not been implemented" + ) @set_module("sisl") def ispmatrixd(matrix, map_row=None, map_col=None): - """ Iterator for iterating rows, columns and data for non-zero elements in a ``scipy.sparse.*_matrix`` (or `SparseCSR`) + """Iterator for iterating rows, columns and data for non-zero elements in a ``scipy.sparse.*_matrix`` (or `SparseCSR`) Parameters ---------- @@ -2160,13 +2249,13 @@ def ispmatrixd(matrix, map_row=None, map_col=None): map_col = lambda x: x # Consider using the numpy nditer function for buffered iterations - #it = np.nditer([geom.o2a(tmp.row), geom.o2a(tmp.col % geom.no), tmp.data], + # it = np.nditer([geom.o2a(tmp.row), geom.o2a(tmp.col % geom.no), tmp.data], # flags=['buffered'], op_flags=['readonly']) if issparse(matrix) and matrix.format == "csr": for r in range(matrix.shape[0]): rr = map_row(r) - for ind in range(matrix.indptr[r], matrix.indptr[r+1]): + for ind in range(matrix.indptr[r], matrix.indptr[r + 1]): yield rr, map_col(matrix.indices[ind]), matrix.data[ind] elif issparse(matrix) and matrix.format == "lil": @@ -2181,7 +2270,7 @@ def ispmatrixd(matrix, map_row=None, map_col=None): elif issparse(matrix) and matrix.format == "csc": for c in range(matrix.shape[1]): cc = map_col(c) - for ind in range(matrix.indptr[c], matrix.indptr[c+1]): + for ind in range(matrix.indptr[c], matrix.indptr[c + 1]): yield map_row(matrix.indices[ind]), cc, matrix.data[ind] elif isinstance(matrix, SparseCSR): @@ -2191,9 +2280,11 @@ def ispmatrixd(matrix, map_row=None, map_col=None): if n == 0: continue ptr = matrix.ptr[r] - sl = slice(ptr, ptr+n, None) + sl = slice(ptr, ptr + n, None) for c, d in zip(map_col(matrix.col[sl]), matrix._D[sl, :]): yield rr, c, d else: - raise NotImplementedError("The iterator for this sparse matrix has not been implemented") + raise NotImplementedError( + "The iterator for this sparse matrix has not been implemented" + ) diff --git a/src/sisl/sparse_geometry.py b/src/sisl/sparse_geometry.py index c5b0364d01..bcd883c38a 100644 --- a/src/sisl/sparse_geometry.py +++ b/src/sisl/sparse_geometry.py @@ -42,11 +42,11 @@ from .sparse import SparseCSR, _ncol_to_indptr, issparse from .utils.ranges import list2str -__all__ = ['SparseAtom', 'SparseOrbital'] +__all__ = ["SparseAtom", "SparseOrbital"] class _SparseGeometry(NDArrayOperatorsMixin): - """ Sparse object containing sparse elements for a given geometry. + """Sparse object containing sparse elements for a given geometry. This is a base class intended to be sub-classed because the sparsity information needs to be extracted from the ``_size`` attribute. @@ -61,7 +61,7 @@ class _SparseGeometry(NDArrayOperatorsMixin): """ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): - """ Create sparse object with element between orbitals """ + """Create sparse object with element between orbitals""" self._geometry = geometry # Initialize the sparsity pattern @@ -69,24 +69,24 @@ def __init__(self, geometry, dim=1, dtype=None, nnzpr=None, **kwargs): @property def geometry(self): - """ Associated geometry """ + """Associated geometry""" return self._geometry @property def _size(self): - """ The size of the sparse object """ + """The size of the sparse object""" return self.geometry.na def __len__(self): - """ Number of rows in the basis """ + """Number of rows in the basis""" return self._size def _cls_kwargs(self): - """ Custom keyword arguments when creating a new instance """ + """Custom keyword arguments when creating a new instance""" return {} def reset(self, dim=None, dtype=np.float64, nnzpr=None): - """ The sparsity pattern has all elements removed and everything is reset. + """The sparsity pattern has all elements removed and everything is reset. The object will be the same as if it had been initialized with the same geometry as it were @@ -126,11 +126,11 @@ def reset(self, dim=None, dtype=np.float64, nnzpr=None): self._def_dim = -1 def empty(self, keep_nnz=False): - """ See :meth:`~sparse.SparseCSR.empty` for details """ + """See :meth:`~sparse.SparseCSR.empty` for details""" self._csr.empty(keep_nnz) def copy(self, dtype=None): - """ A copy of this object + """A copy of this object Parameters ---------- @@ -140,39 +140,41 @@ def copy(self, dtype=None): """ if dtype is None: dtype = self.dtype - new = self.__class__(self.geometry.copy(), self.dim, dtype, 1, **self._cls_kwargs()) + new = self.__class__( + self.geometry.copy(), self.dim, dtype, 1, **self._cls_kwargs() + ) # Be sure to copy the content of the SparseCSR object new._csr = self._csr.copy(dtype=dtype) return new @property def dim(self): - """ Number of components per element """ + """Number of components per element""" return self._csr.shape[-1] @property def shape(self): - """ Shape of sparse matrix """ + """Shape of sparse matrix""" return self._csr.shape @property def dtype(self): - """ Data type of sparse elements """ + """Data type of sparse elements""" return self._csr.dtype @property def dkind(self): - """ Data type of sparse elements (in str) """ + """Data type of sparse elements (in str)""" return self._csr.dkind @property def nnz(self): - """ Number of non-zero elements """ + """Number of non-zero elements""" return self._csr.nnz def translate2uc(self, atoms: Optional[AtomsArgument] = None, axes=None): """Translates all primary atoms to the unit cell. - + With this, the coordinates of the geometry are translated to the unit cell and the supercell connections in the matrix are updated accordingly. @@ -197,14 +199,16 @@ def translate2uc(self, atoms: Optional[AtomsArgument] = None, axes=None): if axes: axes = (0, 1, 2) else: - raise ValueError("translate2uc with a bool argument can only be True to signal all axes") - + raise ValueError( + "translate2uc with a bool argument can only be True to signal all axes" + ) + # Sanitize also the atoms argument if atoms is None: ats = slice(None) else: ats = self.geometry._sanitize_atoms(atoms).ravel() - + # Get the fractional coordinates of the associated geometry fxyz = self.geometry.fxyz @@ -222,7 +226,7 @@ def translate2uc(self, atoms: Optional[AtomsArgument] = None, axes=None): def _translate_atoms_sc(self, sc_translations): """Translates atoms across supercells. - This operation results in new coordinates of the associated geometry + This operation results in new coordinates of the associated geometry and new indices for the matrix elements. Parameters @@ -237,18 +241,17 @@ def _translate_atoms_sc(self, sc_translations): """ # Make sure that supercell translations is an array of integers sc_translations = np.asarray(sc_translations, dtype=int) - + # Get the row and column of every element in the matrix rows, cols = self.nonzero() n_rows = len(self) is_atom = n_rows == self.na - # Find out the unit cell indices for the columns, and the index of the supercell # where they are currently located. This is done by dividing into the number of # columns in the unit cell. - # We will do the conversion back to supercell indices when we know their + # We will do the conversion back to supercell indices when we know their # new location after translation, and also the size of the new auxiliary supercell. sc_idx, uc_col = np.divmod(cols, n_rows) @@ -262,52 +265,52 @@ def _translate_atoms_sc(self, sc_translations): # orbital indices at_row = self.o2a(rows) at_col = self.o2a(cols) % self.na - + # Get the supercell indices of the original positions. isc = self.sc_off[sc_idx] - + # We are going to now displace the supercell index of the connections # according to how the two orbitals involved have moved. We store the # result in the same array just to avoid using more memory. isc += sc_translations[at_row] - sc_translations[at_col] - + # It is possible that once we discover the new locations of the connections # we find out that we need a bigger or smaller auxiliary supercell. Find out # the size of the new auxiliary supercell. new_nsc = np.max(abs(isc), axis=0) * 2 + 1 - + # Create a new geometry object with the new auxiliary supercell size. new_geometry = self.geometry.copy() new_geometry.set_nsc(new_nsc) - + # Update the coordinates of the geometry, according to the cell # displacements. new_geometry.xyz = new_geometry.xyz + sc_translations @ new_geometry.cell - + # Find out supercell indices in this new auxiliary supercell new_sc = new_geometry.isc_off[isc[:, 0], isc[:, 1], isc[:, 2]] - + # With this, we can compute the new columns new_cols = uc_col + new_sc * n_rows - + # Build the new csr matrix, which will just be a copy of the current one # but updating the column indices. It is possible that there are column # indices that are -1, which are the placeholders for new elements. We make sure - # that we update only the indices that are not -1. - # We also need to make sure that the shape of the matrix is appropiate + # that we update only the indices that are not -1. + # We also need to make sure that the shape of the matrix is appropiate # for the size of the new auxiliary cell. new_csr = self._csr.copy() new_csr.col[new_csr.col >= 0] = new_cols new_csr._shape = (n_rows, n_rows * new_geometry.n_s, new_csr.shape[-1]) - + # Create the new SparseGeometry matrix and associate to it the csr matrix that we have built. new_matrix = self.__class__(new_geometry) new_matrix._csr = new_csr - + return new_matrix def _translate_cells(self, old, new): - """ Translates all columns in the `old` cell indices to the `new` cell indices + """Translates all columns in the `old` cell indices to the `new` cell indices Since the physical matrices are stored in a CSR form, with shape ``(no, no * n_s)`` each block of ``(no, no)`` refers to supercell matrices with an offset according to the internal @@ -327,8 +330,11 @@ def _translate_cells(self, old, new): new = _a.asarrayi(new).ravel() if len(old) != len(new): - raise ValueError(self.__class__.__name__+".translate_cells requires input and output indices with " - "equal length") + raise ValueError( + self.__class__.__name__ + + ".translate_cells requires input and output indices with " + "equal length" + ) no = self.no # Number of elements per matrix @@ -339,7 +345,7 @@ def _translate_cells(self, old, new): self._csr.translate_columns(old, new) def edges(self, atoms, exclude=None): - """ Retrieve edges (connections) for all `atoms` + """Retrieve edges (connections) for all `atoms` The returned edges are unique and sorted (see `numpy.unique`) and are returned in supercell indices (i.e. ``0 <= edge < self.geometry.na_s``). @@ -358,15 +364,15 @@ def edges(self, atoms, exclude=None): return self._csr.edges(atoms, exclude) def __str__(self): - """ Representation of the sparse model """ + """Representation of the sparse model""" s = f"{self.__class__.__name__}{{dim: {self.dim}, non-zero: {self.nnz}, kind={self.dkind}\n " - return s + str(self.geometry).replace('\n', '\n ') + "\n}" + return s + str(self.geometry).replace("\n", "\n ") + "\n}" def __repr__(self): return f"<{self.__module__}.{self.__class__.__name__} shape={self._csr.shape[:-1]}, dim={self.dim}, nnz={self.nnz}, kind={self.dkind}>" def __getattr__(self, attr): - """ Overload attributes from the hosting geometry + """Overload attributes from the hosting geometry Any attribute not found in the sparse class will be looked up in the hosting geometry. @@ -375,15 +381,15 @@ def __getattr__(self, attr): # Make the indicis behave on the contained sparse matrix def __delitem__(self, key): - """ Delete elements of the sparse elements """ + """Delete elements of the sparse elements""" del self._csr[key] def __contains__(self, key): - """ Check whether a sparse index is non-zero """ + """Check whether a sparse index is non-zero""" return key in self._csr def set_nsc(self, size, *args, **kwargs): - """ Reset the number of allowed supercells in the sparse geometry + """Reset the number of allowed supercells in the sparse geometry If one reduces the number of supercells, *any* sparse element that references the supercell will be deleted. @@ -475,7 +481,7 @@ def set_nsc(self, size, *args, **kwargs): self.geometry.set_nsc(*args, **kwargs) def transpose(self, sort=True): - """ Create the transposed sparse geometry by interchanging supercell indices + """Create the transposed sparse geometry by interchanging supercell indices Sparse geometries are (typically) relying on symmetry in the supercell picture. Thus when one transposes a sparse geometry one should *ideally* get the same @@ -529,12 +535,12 @@ def transpose(self, sort=True): # First extract the actual data ncol = csr.ncol.view() if csr.finalized: - #ptr = csr.ptr.view() + # ptr = csr.ptr.view() col = csr.col.copy() D = csr._D.copy() else: idx = array_arange(csr.ptr[:-1], n=ncol, dtype=int32) - #ptr = _ncol_to_indptr(ncol) + # ptr = _ncol_to_indptr(ncol) col = csr.col[idx] D = csr._D[idx, :].copy() del idx @@ -548,7 +554,7 @@ def transpose(self, sort=True): # row, col, _D # Retrieve all sc-indices in the new transposed array - new_sc_off = lattice.sc_index(- lattice.sc_off) + new_sc_off = lattice.sc_index(-lattice.sc_off) # Calculate the row-offsets in the new sparse geometry row += new_sc_off[lattice.sc_index(lattice.sc_off[col // size, :])] * size @@ -585,14 +591,14 @@ def transpose(self, sort=True): return T def spalign(self, other): - """ See :meth:`~sisl.sparse.SparseCSR.align` for details """ + """See :meth:`~sisl.sparse.SparseCSR.align` for details""" if isinstance(other, SparseCSR): self._csr.align(other) else: self._csr.align(other._csr) def eliminate_zeros(self, *args, **kwargs): - """ Removes all zero elements from the sparse matrix + """Removes all zero elements from the sparse matrix This is an *in-place* operation. @@ -604,7 +610,7 @@ def eliminate_zeros(self, *args, **kwargs): # Create iterations on the non-zero elements def iter_nnz(self): - """ Iterations of the non-zero elements + """Iterations of the non-zero elements An iterator on the sparse matrix with, row and column @@ -618,7 +624,7 @@ def iter_nnz(self): __iter__ = iter_nnz def create_construct(self, R, params): - """ Create a simple function for passing to the `construct` function. + """Create a simple function for passing to the `construct` function. This is simply to leviate the creation of simplistic functions needed for setting up the sparse elements. @@ -652,19 +658,22 @@ def create_construct(self, R, params): construct : routine to create the sparse matrix from a generic function (as returned from `create_construct`) """ if len(R) != len(params): - raise ValueError(f"{self.__class__.__name__}.create_construct got different lengths of `R` and `param`") + raise ValueError( + f"{self.__class__.__name__}.create_construct got different lengths of `R` and `param`" + ) def func(self, ia, atoms, atoms_xyz=None): idx = self.geometry.close(ia, R=R, atoms=atoms, atoms_xyz=atoms_xyz) for ix, p in zip(idx, params): self[ia, ix] = p + func.R = R func.params = params return func - def construct(self, func, na_iR=1000, method='rand', eta=None): - """ Automatically construct the sparse model based on a function that does the setting up of the elements + def construct(self, func, na_iR=1000, method="rand", eta=None): + """Automatically construct the sparse model based on a function that does the setting up of the elements This may be called in two variants. @@ -709,12 +718,16 @@ def construct(self, func, na_iR=1000, method='rand', eta=None): """ if not callable(func): if not isinstance(func, (tuple, list)): - raise ValueError('Passed `func` which is not a function, nor tuple/list of `R, param`') + raise ValueError( + "Passed `func` which is not a function, nor tuple/list of `R, param`" + ) if np.any(diff(self.geometry.lasto) > 1): - raise ValueError("Automatically setting a sparse model " - "for systems with atoms having more than 1 " - "orbital *must* be done by your-self. You have to define a corresponding `func`.") + raise ValueError( + "Automatically setting a sparse model " + "for systems with atoms having more than 1 " + "orbital *must* be done by your-self. You have to define a corresponding `func`." + ) # Convert to a proper function func = self.create_construct(func[0], func[1]) @@ -738,7 +751,6 @@ def construct(self, func, na_iR=1000, method='rand', eta=None): # Do the loop for ias, idxs in self.geometry.iter_block(iR=iR, method=method, R=R): - # Get all the indexed atoms... # This speeds up the searching for coordinates... idxs_xyz = self.geometry[idxs] @@ -753,11 +765,11 @@ def construct(self, func, na_iR=1000, method='rand', eta=None): @property def finalized(self): - """ Whether the contained data is finalized and non-used elements have been removed """ + """Whether the contained data is finalized and non-used elements have been removed""" return self._csr.finalized def remove(self, atoms): - """ Create a subset of this sparse matrix by removing the atoms corresponding to `atoms` + """Create a subset of this sparse matrix by removing the atoms corresponding to `atoms` Negative indices are wrapped and thus works. @@ -777,7 +789,7 @@ def remove(self, atoms): return self.sub(atoms) def sub(self, atoms): - """ Create a subset of this sparse matrix by retaining the atoms corresponding to `atoms` + """Create a subset of this sparse matrix by retaining the atoms corresponding to `atoms` Indices passed must be unique. @@ -797,7 +809,7 @@ def sub(self, atoms): pass def swap(self, a, b): - """ Swaps atoms in the sparse geometry to obtain a new order of atoms + """Swaps atoms in the sparse geometry to obtain a new order of atoms This can be used to reorder elements of a geometry. @@ -819,7 +831,7 @@ def swap(self, a, b): return self.sub(full) def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): - """ Untiles a sparse model into a minimum segment, reverse of `tile` + """Untiles a sparse model into a minimum segment, reverse of `tile` Parameters ---------- @@ -845,17 +857,19 @@ def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): # Check whether the warning exists if len(w) > 0: if issubclass(w[-1].category, SislWarning): - warn(f"{str(w[-1].message)}\n---\n" - "The sparse matrix cannot be untiled as the structure " - "cannot be tiled accordingly. ANY use of the model has been " - "relieved from sisl.") + warn( + f"{str(w[-1].message)}\n---\n" + "The sparse matrix cannot be untiled as the structure " + "cannot be tiled accordingly. ANY use of the model has been " + "relieved from sisl." + ) # Now we need to re-create number of supercells no = getattr(self, f"n{prefix}") geom_no = getattr(geom, f"n{prefix}") # orig-orbs - orig_orbs = _a.arangei(segment * geom_no, (segment+1) * geom_no) + orig_orbs = _a.arangei(segment * geom_no, (segment + 1) * geom_no) # create correct linear offset due to the segment. # Further below we will take out the linear indices by modulo and integer @@ -887,8 +901,10 @@ def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): del S if len(sub) == 0: - raise ValueError(f"{self.__class__.__name__}.untile couples to no " - "matrix elements, an empty sparse model cannot be split.") + raise ValueError( + f"{self.__class__.__name__}.untile couples to no " + "matrix elements, an empty sparse model cannot be split." + ) # Figure out the supercell indices of sub sub_sc = getattr(self.geometry, f"{prefix}2isc")(sub) @@ -925,9 +941,11 @@ def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): else: msg = f"The returned matrix will not have symmetric couplings due to sym={sym} argument." - warn(f"{self.__class__.__name__}.untile matrix has connections crossing " - "the entire unit cell. " - f"This may result in wrong behavior due to non-unique matrix elements. {msg}") + warn( + f"{self.__class__.__name__}.untile matrix has connections crossing " + "the entire unit cell. " + f"This may result in wrong behavior due to non-unique matrix elements. {msg}" + ) else: # even case @@ -937,9 +955,11 @@ def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): # or [0 1 2] # [0] -> [1] positive direction # [0] <- [2] negative direction - warn(f"{self.__class__.__name__}.untile may have connections crossing " - "the entire unit cell. " - "This may result in wrong behavior due to non-unique matrix elements.") + warn( + f"{self.__class__.__name__}.untile may have connections crossing " + "the entire unit cell. " + "This may result in wrong behavior due to non-unique matrix elements." + ) else: # we have something like @@ -958,7 +978,7 @@ def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): found = True while found: try: - if single_sub_lsc[axis0+pos_nsc+1] == pos_nsc + 1: + if single_sub_lsc[axis0 + pos_nsc + 1] == pos_nsc + 1: pos_nsc += 1 else: found = False @@ -969,7 +989,7 @@ def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): found = True while found: try: - if single_sub_lsc[axis0+neg_nsc-1] == neg_nsc - 1: + if single_sub_lsc[axis0 + neg_nsc - 1] == neg_nsc - 1: neg_nsc -= 1 else: found = False @@ -993,9 +1013,11 @@ def untile(self, prefix, reps, axis, segment=0, *args, sym=True, **kwargs): # Create the to-columns if sub_lsc.max() != -sub_lsc.min(): - raise ValueError(f"{self.__class__.__name__}.untile found inconsistent supercell matrix. " - f"The untiled sparse matrix couples to {sub_lsc} supercells but expected a symmetric set of couplings. " - "This may happen if doing multiple cuts along the same direction, or if the matrix is not correctly constructed.") + raise ValueError( + f"{self.__class__.__name__}.untile found inconsistent supercell matrix. " + f"The untiled sparse matrix couples to {sub_lsc} supercells but expected a symmetric set of couplings. " + "This may happen if doing multiple cuts along the same direction, or if the matrix is not correctly constructed." + ) # Update number of super-cells geom.set_nsc(nsc) @@ -1015,8 +1037,11 @@ def conv(dim): cols = cols % geom_no + geom.sc_index(cols_lsc) * geom_no - return csr_matrix((csr.data, cols, csr.indptr), - shape=(geom_no, geom_no * geom.n_s), dtype=self.dtype) + return csr_matrix( + (csr.data, cols, csr.indptr), + shape=(geom_no, geom_no * geom.n_s), + dtype=self.dtype, + ) Ps = [conv(dim) for dim in range(self.dim)] S = self.fromsp(geom, Ps, **self._cls_kwargs()) @@ -1028,7 +1053,7 @@ def conv(dim): return S def unrepeat(self, reps, axis, segment=0, *args, sym=True, **kwargs): - """ Unrepeats the sparse model into different parts (retaining couplings) + """Unrepeats the sparse model into different parts (retaining couplings) Please see `untile` for details, the algorithm and arguments are the same however, this is the opposite of `repeat`. @@ -1037,7 +1062,7 @@ def unrepeat(self, reps, axis, segment=0, *args, sym=True, **kwargs): return self.sub(atoms).untile(reps, axis, segment, *args, sym=sym, **kwargs) def finalize(self): - """ Finalizes the model + """Finalizes the model Finalizes the model so that all non-used elements are removed. I.e. this simply reduces the memory requirement for the sparse matrix. @@ -1047,7 +1072,7 @@ def finalize(self): self._csr.finalize() def tocsr(self, dim=0, isc=None, **kwargs): - """ Return a :class:`~scipy.sparse.csr_matrix` for the specified dimension + """Return a :class:`~scipy.sparse.csr_matrix` for the specified dimension Parameters ---------- @@ -1058,11 +1083,13 @@ def tocsr(self, dim=0, isc=None, **kwargs): the supercell index, or all (if ``isc=None``) """ if isc is not None: - raise NotImplementedError("Requesting sub-sparse has not been implemented yet") + raise NotImplementedError( + "Requesting sub-sparse has not been implemented yet" + ) return self._csr.tocsr(dim, **kwargs) def spsame(self, other): - """ Compare two sparse objects and check whether they have the same entries. + """Compare two sparse objects and check whether they have the same entries. This does not necessarily mean that the elements are the same """ @@ -1070,7 +1097,7 @@ def spsame(self, other): @classmethod def fromsp(cls, geometry, P, **kwargs): - r""" Create a sparse model from a preset `Geometry` and a list of sparse matrices + r"""Create a sparse model from a preset `Geometry` and a list of sparse matrices The passed sparse matrices are in one of `scipy.sparse` formats. @@ -1100,8 +1127,10 @@ def fromsp(cls, geometry, P, **kwargs): p._csr = p._csr.fromsp(*P, dtype=kwargs.get("dtype")) if p._size != P[0].shape[0]: - raise ValueError(f"{cls.__name__}.fromsp cannot create a new class, the geometry " - "and sparse matrices does not have coinciding dimensions size != P[0].shape[0]") + raise ValueError( + f"{cls.__name__}.fromsp cannot create a new class, the geometry " + "and sparse matrices does not have coinciding dimensions size != P[0].shape[0]" + ) return p @@ -1136,8 +1165,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if out is not None: # check that the resulting variable is indeed a sparsecsr - assert isinstance(result, SparseCSR), \ - f"{self.__class__.__name__} ({ufunc.__name__}) requires out= to match the resulting operator" + assert isinstance( + result, SparseCSR + ), f"{self.__class__.__name__} ({ufunc.__name__}) requires out= to match the resulting operator" if isinstance(result, SparseCSR): # return a copy with the sparse result into the output sparse @@ -1152,29 +1182,29 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return out def __getstate__(self): - """ Return dictionary with the current state """ + """Return dictionary with the current state""" return { - 'geometry': self.geometry.__getstate__(), - 'csr': self._csr.__getstate__() + "geometry": self.geometry.__getstate__(), + "csr": self._csr.__getstate__(), } def __setstate__(self, state): - """ Return dictionary with the current state """ + """Return dictionary with the current state""" geom = Geometry([0] * 3, Atom(1)) - geom.__setstate__(state['geometry']) + geom.__setstate__(state["geometry"]) self._geometry = geom csr = SparseCSR((2, 2, 2)) - csr.__setstate__(state['csr']) + csr.__setstate__(state["csr"]) self._csr = csr self._def_dim = -1 @set_module("sisl") class SparseAtom(_SparseGeometry): - """ Sparse object with number of rows equal to the total number of atoms in the `Geometry` """ + """Sparse object with number of rows equal to the total number of atoms in the `Geometry`""" def __getitem__(self, key): - """ Elements for the index(s) """ + """Elements for the index(s)""" dd = self._def_dim if len(key) > 2: # This may be a specification of supercell indices @@ -1190,7 +1220,7 @@ def __getitem__(self, key): return d def __setitem__(self, key, val): - """ Set or create elements in the sparse data + """Set or create elements in the sparse data Override set item for slicing operations and enables easy setting of parameters in a sparse matrix @@ -1213,7 +1243,7 @@ def _size(self): return self.geometry.na def nonzero(self, atoms=None, only_cols=False): - """ Indices row and column indices where non-zero elements exists + """Indices row and column indices where non-zero elements exists Parameters ---------- @@ -1229,7 +1259,7 @@ def nonzero(self, atoms=None, only_cols=False): return self._csr.nonzero(rows=atoms, only_cols=only_cols) def iter_nnz(self, atoms=None): - """ Iterations of the non-zero elements + """Iterations of the non-zero elements An iterator on the sparse matrix with, row and column @@ -1250,7 +1280,7 @@ def iter_nnz(self, atoms=None): yield from self._csr.iter_nnz(atoms) def set_nsc(self, *args, **kwargs): - """ Reset the number of allowed supercells in the sparse atom + """Reset the number of allowed supercells in the sparse atom If one reduces the number of supercells *any* sparse element that references the supercell will be deleted. @@ -1264,7 +1294,7 @@ def set_nsc(self, *args, **kwargs): super().set_nsc(self.na, *args, **kwargs) def untile(self, reps, axis, segment=0, *args, sym=True, **kwargs): - """ Untiles the sparse model into different parts (retaining couplings) + """Untiles the sparse model into different parts (retaining couplings) Recreates a new sparse object with only the cutted atoms in the structure. This will preserve matrix elements in the supercell. @@ -1319,10 +1349,10 @@ def untile(self, reps, axis, segment=0, *args, sym=True, **kwargs): tile : opposite of this method Geometry.untile : same as this method, see details about parameters here """ - return super().untile('a', reps, axis, segment, *args, sym=sym, **kwargs) + return super().untile("a", reps, axis, segment, *args, sym=sym, **kwargs) def sub(self, atoms): - """ Create a subset of this sparse matrix by only retaining the elements corresponding to the `atoms` + """Create a subset of this sparse matrix by only retaining the elements corresponding to the `atoms` Indices passed *MUST* be unique. @@ -1355,7 +1385,7 @@ def sub(self, atoms): return S def tile(self, reps, axis): - """ Create a tiled sparse atom object, equivalent to `Geometry.tile` + """Create a tiled sparse atom object, equivalent to `Geometry.tile` The already existing sparse elements are extrapolated to the new supercell by repeating them in blocks like the coordinates. @@ -1435,13 +1465,15 @@ def tile(self, reps, axis): # Clean-up del isc, JA - S._csr = SparseCSR((tile(D, (reps, 1)), indices.ravel(), indptr), - shape=(geom_n.na, geom_n.na_s)) + S._csr = SparseCSR( + (tile(D, (reps, 1)), indices.ravel(), indptr), + shape=(geom_n.na, geom_n.na_s), + ) return S def repeat(self, reps, axis): - """ Create a repeated sparse atom object, equivalent to `Geometry.repeat` + """Create a repeated sparse atom object, equivalent to `Geometry.repeat` The already existing sparse elements are extrapolated to the new supercell by repeating them in blocks like the coordinates. @@ -1504,7 +1536,6 @@ def repeat(self, reps, axis): A = isc[:, axis] - 1 for rep in range(reps): - # Update the offset A += 1 # Correct supercell information @@ -1524,15 +1555,16 @@ def repeat(self, reps, axis): D = tile(D, (reps, 1)) else: ntile = ftool.partial(tile, reps=(reps, 1)) - D = np.vstack(tuple(map(ntile, np.split(D, _a.cumsumi(csr.ncol[:-1]), axis=0)))) + D = np.vstack( + tuple(map(ntile, np.split(D, _a.cumsumi(csr.ncol[:-1]), axis=0))) + ) - S._csr = SparseCSR((D, indices, indptr), - shape=(geom_n.na, geom_n.na_s)) + S._csr = SparseCSR((D, indices, indptr), shape=(geom_n.na, geom_n.na_s)) return S def rij(self, dtype=np.float64): - r""" Create a sparse matrix with the distance between atoms + r"""Create a sparse matrix with the distance between atoms Parameters ---------- @@ -1547,11 +1579,11 @@ def rij(self, dtype=np.float64): structure is completed. """ R = self.Rij(dtype) - R._csr = np.sum(R._csr ** 2, axis=-1) ** 0.5 + R._csr = np.sum(R._csr**2, axis=-1) ** 0.5 return R def Rij(self, dtype=np.float64): - r""" Create a sparse matrix with vectors between atoms + r"""Create a sparse matrix with vectors between atoms Parameters ---------- @@ -1592,10 +1624,10 @@ def Rij(self, dtype=np.float64): @set_module("sisl") class SparseOrbital(_SparseGeometry): - """ Sparse object with number of rows equal to the total number of orbitals in the `Geometry` """ + """Sparse object with number of rows equal to the total number of orbitals in the `Geometry`""" def __getitem__(self, key): - """ Elements for the index(s) """ + """Elements for the index(s)""" dd = self._def_dim if len(key) > 2: # This may be a specification of supercell indices @@ -1611,7 +1643,7 @@ def __getitem__(self, key): return d def __setitem__(self, key, val): - """ Set or create elements in the sparse data + """Set or create elements in the sparse data Override set item for slicing operations and enables easy setting of parameters in a sparse matrix @@ -1634,7 +1666,7 @@ def _size(self): return self.geometry.no def edges(self, atoms=None, exclude=None, orbitals=None): - """ Retrieve edges (connections) for all `atoms` + """Retrieve edges (connections) for all `atoms` The returned edges are unique and sorted (see `numpy.unique`) and are returned in supercell indices (i.e. ``0 <= edge < self.geometry.no_s``). @@ -1654,13 +1686,19 @@ def edges(self, atoms=None, exclude=None, orbitals=None): SparseCSR.edges: the underlying routine used for extracting the edges """ if atoms is None and orbitals is None: - raise ValueError(f"{self.__class__.__name__}.edges must have either 'atoms' or 'orbitals' keyword defined.") + raise ValueError( + f"{self.__class__.__name__}.edges must have either 'atoms' or 'orbitals' keyword defined." + ) if orbitals is None: - return unique(self.geometry.o2a(self._csr.edges(self.geometry.a2o(atoms, True), exclude))) + return unique( + self.geometry.o2a( + self._csr.edges(self.geometry.a2o(atoms, True), exclude) + ) + ) return self._csr.edges(orbitals, exclude) def nonzero(self, atoms=None, only_cols=False): - """ Indices row and column indices where non-zero elements exists + """Indices row and column indices where non-zero elements exists Parameters ---------- @@ -1680,7 +1718,7 @@ def nonzero(self, atoms=None, only_cols=False): return self._csr.nonzero(rows=rows, only_cols=only_cols) def iter_nnz(self, atoms=None, orbitals=None): - """ Iterations of the non-zero elements + """Iterations of the non-zero elements An iterator on the sparse matrix with, row and column @@ -1708,7 +1746,7 @@ def iter_nnz(self, atoms=None, orbitals=None): yield from self._csr.iter_nnz(orbitals) def set_nsc(self, *args, **kwargs): - """ Reset the number of allowed supercells in the sparse orbital + """Reset the number of allowed supercells in the sparse orbital If one reduces the number of supercells *any* sparse element that references the supercell will be deleted. @@ -1722,7 +1760,7 @@ def set_nsc(self, *args, **kwargs): super().set_nsc(self.no, *args, **kwargs) def remove(self, atoms): - """ Remove a subset of this sparse matrix by only retaining the atoms corresponding to `atoms` + """Remove a subset of this sparse matrix by only retaining the atoms corresponding to `atoms` Parameters ---------- @@ -1739,7 +1777,7 @@ def remove(self, atoms): return super().remove(atoms) def remove_orbital(self, atoms, orbitals): - """ Remove a subset of orbitals on `atoms` according to `orbitals` + """Remove a subset of orbitals on `atoms` according to `orbitals` For more detailed examples, please see the equivalent (but opposite) method `sub_orbital`. @@ -1786,7 +1824,7 @@ def remove_orbital(self, atoms, orbitals): return self.sub_orbital(atoms, orbitals) def sub(self, atoms): - """ Create a subset of this sparse matrix by only retaining the atoms corresponding to `atoms` + """Create a subset of this sparse matrix by only retaining the atoms corresponding to `atoms` Negative indices are wrapped and thus works, supercell atoms are also wrapped to the unit-cell. @@ -1826,7 +1864,7 @@ def sub(self, atoms): return S def sub_orbital(self, atoms, orbitals): - r""" Retain only a subset of the orbitals on `atoms` according to `orbitals` + r"""Retain only a subset of the orbitals on `atoms` according to `orbitals` This allows one to retain only a given subset of the sparse matrix elements. @@ -1934,7 +1972,7 @@ def sub_orbital(self, atoms, orbitals): return SG def tile(self, reps, axis): - """ Create a tiled sparse orbital object, equivalent to `Geometry.tile` + """Create a tiled sparse orbital object, equivalent to `Geometry.tile` The already existing sparse elements are extrapolated to the new supercell by repeating them in blocks like the coordinates. @@ -2008,13 +2046,15 @@ def tile(self, reps, axis): # Clean-up del isc, JO - S._csr = SparseCSR((tile(D, (reps, 1)), indices.ravel(), indptr), - shape=(geom_n.no, geom_n.no_s)) + S._csr = SparseCSR( + (tile(D, (reps, 1)), indices.ravel(), indptr), + shape=(geom_n.no, geom_n.no_s), + ) return S def untile(self, reps, axis, segment=0, *args, sym=True, **kwargs): - """ Untiles the sparse model into different parts (retaining couplings) + """Untiles the sparse model into different parts (retaining couplings) Recreates a new sparse object with only the cutted atoms in the structure. This will preserve matrix elements in the supercell. @@ -2069,10 +2109,10 @@ def untile(self, reps, axis, segment=0, *args, sym=True, **kwargs): tile : opposite of this method Geometry.untile : same as this method, see details about parameters here """ - return super().untile('o', reps, axis, segment, *args, sym=sym, **kwargs) + return super().untile("o", reps, axis, segment, *args, sym=sym, **kwargs) def repeat(self, reps, axis): - """ Create a repeated sparse orbital object, equivalent to `Geometry.repeat` + """Create a repeated sparse orbital object, equivalent to `Geometry.repeat` The already existing sparse elements are extrapolated to the new supercell by repeating them in blocks like the coordinates. @@ -2118,8 +2158,9 @@ def repeat(self, reps, axis): sc_index = geom_n.sc_index # Create new indptr, indices and D - idx = array_arange(repeat(geom.firsto[:-1], reps), - repeat(geom.firsto[1:], reps)) + idx = array_arange( + repeat(geom.firsto[:-1], reps), repeat(geom.firsto[1:], reps) + ) ncol = csr.ncol[idx] # Now indptr is complete indptr = _ncol_to_indptr(ncol) @@ -2163,7 +2204,6 @@ def repeat(self, reps, axis): # Create repetitions for _ in range(reps): - # Update atomic offset OA += AO # Update the offset @@ -2180,13 +2220,12 @@ def repeat(self, reps, axis): # In the repeat we have to tile individual atomic couplings # So we should split the arrays and tile them individually - S._csr = SparseCSR((D, indices, indptr), - shape=(geom_n.no, geom_n.no_s)) + S._csr = SparseCSR((D, indices, indptr), shape=(geom_n.no, geom_n.no_s)) return S - def rij(self, what='orbital', dtype=np.float64): - r""" Create a sparse matrix with the distance between atoms/orbitals + def rij(self, what="orbital", dtype=np.float64): + r"""Create a sparse matrix with the distance between atoms/orbitals Parameters ---------- @@ -2206,11 +2245,11 @@ def rij(self, what='orbital', dtype=np.float64): structure is completed. """ R = self.Rij(what, dtype) - R._csr = np.sum(R._csr ** 2, axis=-1) ** 0.5 + R._csr = np.sum(R._csr**2, axis=-1) ** 0.5 return R - def Rij(self, what='orbital', dtype=np.float64): - r""" Create a sparse matrix with the vectors between atoms/orbitals + def Rij(self, what="orbital", dtype=np.float64): + r"""Create a sparse matrix with the vectors between atoms/orbitals Parameters ---------- @@ -2236,7 +2275,7 @@ def Rij(self, what='orbital', dtype=np.float64): ptr = self._csr.ptr col = self._csr.col - if what == 'atom': + if what == "atom": R = SparseAtom(geom, 3, dtype, nnzpr=np.amax(ncol)) Rij = geom.Rij o2a = geom.o2a @@ -2245,10 +2284,10 @@ def Rij(self, what='orbital', dtype=np.float64): orow = _a.arangei(self.shape[0]) # Loop on orbitals and atoms for io, ia in zip(orow, o2a(orow)): - coln = unique(o2a(col[ptr[io]:ptr[io]+ncol[io]])) + coln = unique(o2a(col[ptr[io] : ptr[io] + ncol[io]])) R[ia, coln] = Rij(ia, coln) - elif what in ['orbital', 'orb']: + elif what in ["orbital", "orb"]: # We create an *exact* copy of the Rij R = SparseOrbital(geom, 3, dtype, nnzpr=1) Rij = geom.oRij @@ -2266,12 +2305,14 @@ def Rij(self, what='orbital', dtype=np.float64): R._csr._D[sl, :] = Rij(io, col[sl]) else: - raise ValueError(self.__class__.__name__ + '.Rij "what" is not one of [atom, orbital].') + raise ValueError( + self.__class__.__name__ + '.Rij "what" is not one of [atom, orbital].' + ) return R def add(self, other, axis=None, offset=(0, 0, 0)): - r""" Add two sparse matrices by adding the parameters to one set. The final matrix will have no couplings between `self` and `other` + r"""Add two sparse matrices by adding the parameters to one set. The final matrix will have no couplings between `self` and `other` The final sparse matrix will not have any couplings between `self` and `other`. Not even if they have commensurate overlapping regions. If you want to create couplings you have to use `append` but that @@ -2295,13 +2336,22 @@ def add(self, other, axis=None, offset=(0, 0, 0)): """ # Check that the sparse matrices are compatible if not (type(self) is type(other)): - raise ValueError(self.__class__.__name__ + f'.add requires other to be of same type: {other.__class__.__name__}') + raise ValueError( + self.__class__.__name__ + + f".add requires other to be of same type: {other.__class__.__name__}" + ) if self.dtype != other.dtype: - raise ValueError(self.__class__.__name__ + '.add requires the same datatypes in the two matrices.') + raise ValueError( + self.__class__.__name__ + + ".add requires the same datatypes in the two matrices." + ) if self.dim != other.dim: - raise ValueError(self.__class__.__name__ + '.add requires the same number of dimensions in the matrix.') + raise ValueError( + self.__class__.__name__ + + ".add requires the same number of dimensions in the matrix." + ) if axis is None: geom = self.geometry.add(other.geometry, offset=offset) @@ -2314,7 +2364,7 @@ def add(self, other, axis=None, offset=(0, 0, 0)): # New indices and data (the constructor for SparseCSR copies) full = self.__class__(geom, self.dim, self.dtype, 1, **self._cls_kwargs()) full._csr.ptr = concatenate((self._csr.ptr[:-1], other._csr.ptr)) - full._csr.ptr[self.no:] += self._csr.ptr[-1] + full._csr.ptr[self.no :] += self._csr.ptr[-1] full._csr.ncol = concatenate((self._csr.ncol, other._csr.ncol)) full._csr._D = concatenate((self._csr._D, other._csr._D)) full._csr._nnz = full._csr.ncol.sum() @@ -2354,7 +2404,10 @@ def add(self, other, axis=None, offset=(0, 0, 0)): except ValueError: idx_delete.append(isc) # o_idx are transferred to s_idx - transfer_idx[o_idx, :] += _a.arangei(1, other.geometry.n_s + 1)[s_idx].reshape(-1, 1) * self.geometry.no + transfer_idx[o_idx, :] += ( + _a.arangei(1, other.geometry.n_s + 1)[s_idx].reshape(-1, 1) + * self.geometry.no + ) # Remove some columns transfer_idx[idx_delete, :] = full_no_s + 1 # Clean-up to not confuse the rest of the algorithm @@ -2362,8 +2415,8 @@ def add(self, other, axis=None, offset=(0, 0, 0)): # Now figure out if the supercells can be kept, at all... # find SC indices in other corresponding to self - #o_idx_uc = other.geometry.lattice.sc_index([0] * 3) - #o_idx_sc = _a.arangei(other.geometry.lattice.n_s) + # o_idx_uc = other.geometry.lattice.sc_index([0] * 3) + # o_idx_sc = _a.arangei(other.geometry.lattice.n_s) # Remove couplings along axis for i in range(3): @@ -2395,7 +2448,7 @@ def add(self, other, axis=None, offset=(0, 0, 0)): return full def prepend(self, other, axis, eps=0.005, scale=1): - r""" See `append` for details + r"""See `append` for details This is currently equivalent to: @@ -2404,7 +2457,7 @@ def prepend(self, other, axis, eps=0.005, scale=1): return other.append(self, axis, eps, scale) def append(self, other, axis, eps=0.005, scale=1): - r""" Append `other` along `axis` to construct a new connected sparse matrix + r"""Append `other` along `axis` to construct a new connected sparse matrix This method tries to append two sparse geometry objects together by the following these steps: @@ -2485,24 +2538,38 @@ def append(self, other, axis, eps=0.005, scale=1): a new instance with two sparse matrices joined and appended together """ if not (type(self) is type(other)): - raise ValueError(f"{self.__class__.__name__}.append requires other to be of same type: {other.__class__.__name__}") + raise ValueError( + f"{self.__class__.__name__}.append requires other to be of same type: {other.__class__.__name__}" + ) if self.geometry.nsc[axis] > 3 or other.geometry.nsc[axis] > 3: - raise ValueError(f"{self.__class__.__name__}.append requires sparse-geometries to maximally " - "have 3 supercells along appending axis.") + raise ValueError( + f"{self.__class__.__name__}.append requires sparse-geometries to maximally " + "have 3 supercells along appending axis." + ) if not allclose(self.geometry.nsc, other.geometry.nsc): - raise ValueError(f"{self.__class__.__name__}.append requires sparse-geometries to have the same " - "number of supercells along all directions.") - - if not allclose(self.geometry.lattice._isc_off, other.geometry.lattice._isc_off): - raise ValueError(f"{self.__class__.__name__}.append requires supercell offsets to be the same.") + raise ValueError( + f"{self.__class__.__name__}.append requires sparse-geometries to have the same " + "number of supercells along all directions." + ) + + if not allclose( + self.geometry.lattice._isc_off, other.geometry.lattice._isc_off + ): + raise ValueError( + f"{self.__class__.__name__}.append requires supercell offsets to be the same." + ) if self.dtype != other.dtype: - raise ValueError(f"{self.__class__.__name__}.append requires the same datatypes in the two matrices.") + raise ValueError( + f"{self.__class__.__name__}.append requires the same datatypes in the two matrices." + ) if self.dim != other.dim: - raise ValueError(f"{self.__class__.__name__}.append requires the same number of dimensions in the matrix.") + raise ValueError( + f"{self.__class__.__name__}.append requires the same number of dimensions in the matrix." + ) if np.asarray(scale).size == 1: scale = np.array([scale, scale]) @@ -2517,7 +2584,7 @@ def append(self, other, axis, eps=0.005, scale=1): # New indices and data (the constructor for SparseCSR copies) full = self.__class__(geom, self.dim, self.dtype, 1, **self._cls_kwargs()) full._csr.ptr = concatenate((self._csr.ptr[:-1], other._csr.ptr)) - full._csr.ptr[self.no:] += self._csr.ptr[-1] + full._csr.ptr[self.no :] += self._csr.ptr[-1] full._csr.ncol = concatenate((self._csr.ncol, other._csr.ncol)) full._csr._D = concatenate((self._csr._D, other._csr._D)) full._csr._nnz = full._csr.ncol.sum() @@ -2534,7 +2601,9 @@ def append(self, other, axis, eps=0.005, scale=1): o_col = other._csr.col.copy() # transfer transfer_idx = _a.arangei(other.geometry.no_s).reshape(-1, other.geometry.no) - transfer_idx += _a.arangei(1, other.geometry.n_s + 1).reshape(-1, 1) * self.geometry.no + transfer_idx += ( + _a.arangei(1, other.geometry.n_s + 1).reshape(-1, 1) * self.geometry.no + ) idx = array_arange(other._csr.ptr[:-1], n=other._csr.ncol) o_col[idx] = transfer_idx.ravel()[o_col[idx]] @@ -2553,33 +2622,54 @@ def append(self, other, axis, eps=0.005, scale=1): # 1. find overlapping atoms along axis idx_s_first, idx_o_first = self.geometry.overlap(other.geometry, eps=eps) - idx_s_last, idx_o_last = self.geometry.overlap(other.geometry, eps=eps, - offset=-self.geometry.lattice.cell[axis, :], - offset_other=-other.geometry.lattice.cell[axis, :]) + idx_s_last, idx_o_last = self.geometry.overlap( + other.geometry, + eps=eps, + offset=-self.geometry.lattice.cell[axis, :], + offset_other=-other.geometry.lattice.cell[axis, :], + ) + # IFF idx_s_* contains duplicates, then we have multiple overlapping atoms which is not # allowed def _test(diff): if diff.size != diff.nonzero()[0].size: - raise ValueError(f"{self.__class__.__name__}.append requires that there is maximally one " - "atom overlapping one other atom in the other structure.") + raise ValueError( + f"{self.__class__.__name__}.append requires that there is maximally one " + "atom overlapping one other atom in the other structure." + ) + _test(diff(idx_s_first)) _test(diff(idx_s_last)) # Also ensure that atoms have the same number of orbitals in the two cases - if (not allclose(self.geometry.orbitals[idx_s_first], other.geometry.orbitals[idx_o_first])) or \ - (not allclose(self.geometry.orbitals[idx_s_last], other.geometry.orbitals[idx_o_last])): - raise ValueError(f"{self.__class__.__name__}.append requires the overlapping geometries " - "to have the same number of orbitals per atom that is to be replaced.") + if ( + not allclose( + self.geometry.orbitals[idx_s_first], + other.geometry.orbitals[idx_o_first], + ) + ) or ( + not allclose( + self.geometry.orbitals[idx_s_last], other.geometry.orbitals[idx_o_last] + ) + ): + raise ValueError( + f"{self.__class__.__name__}.append requires the overlapping geometries " + "to have the same number of orbitals per atom that is to be replaced." + ) def _check_edges_and_coordinates(spgeom, atoms, isc, err_help): # Figure out if we have found all couplings geom = spgeom.geometry # Find orbitals that we wish to exclude from the orbital connections # This ensures that we only find couplings crossing the supercell boundaries - irrelevant_sc = delete(_a.arangei(geom.lattice.n_s), geom.lattice.sc_index(isc)) + irrelevant_sc = delete( + _a.arangei(geom.lattice.n_s), geom.lattice.sc_index(isc) + ) sc_orbitals = _a.arangei(geom.no_s).reshape(geom.lattice.n_s, -1) exclude = sc_orbitals[irrelevant_sc, :].ravel() # get connections and transfer them to the unit-cell - edges_sc = geom.o2a(spgeom.edges(orbitals=_a.arangei(geom.no), exclude=exclude), True) + edges_sc = geom.o2a( + spgeom.edges(orbitals=_a.arangei(geom.no), exclude=exclude), True + ) edges_uc = geom.sc2uc(edges_sc, True) edges_valid = np.isin(edges_uc, atoms, assume_unique=True) if not np.all(edges_valid): @@ -2595,14 +2685,19 @@ def _check_edges_and_coordinates(spgeom, atoms, isc, err_help): # This will be much faster for isc in unique(isc_off): idx = (isc_off == isc).nonzero()[0] - sc_off_atoms.append("{k}: {v}".format( - k=str(geom.lattice.sc_off[isc]), - v=list2str(np.sort(uca[idx])))) + sc_off_atoms.append( + "{k}: {v}".format( + k=str(geom.lattice.sc_off[isc]), + v=list2str(np.sort(uca[idx])), + ) + ) sc_off_atoms = "\n ".join(sc_off_atoms) - raise ValueError(f"{self.__class__.__name__}.append requires matching coupling elements.\n\n" - f"The following atoms in a {err_help[1]} connection of `{err_help[0]}` super-cell " - "are connected from the unit cell, but are not found in matches:\n\n" - f"[sc-offset]: atoms\n {sc_off_atoms}") + raise ValueError( + f"{self.__class__.__name__}.append requires matching coupling elements.\n\n" + f"The following atoms in a {err_help[1]} connection of `{err_help[0]}` super-cell " + "are connected from the unit cell, but are not found in matches:\n\n" + f"[sc-offset]: atoms\n {sc_off_atoms}" + ) # setup supercells to look up isc_inplace = [None] * 3 @@ -2615,13 +2710,21 @@ def _check_edges_and_coordinates(spgeom, atoms, isc, err_help): # Check that edges and overlapping atoms are the same (or at least that the # edges are all in the overlapping region) # [self|other]: self sc-connections forward must be on left-aligned matching atoms - _check_edges_and_coordinates(self, idx_s_first, isc_forward, err_help=("self", "forward")) + _check_edges_and_coordinates( + self, idx_s_first, isc_forward, err_help=("self", "forward") + ) # [other|self]: other sc-connections forward must be on left-aligned matching atoms - _check_edges_and_coordinates(other, idx_o_first, isc_forward, err_help=("other", "forward")) + _check_edges_and_coordinates( + other, idx_o_first, isc_forward, err_help=("other", "forward") + ) # [other|self]: self sc-connections backward must be on right-aligned matching atoms - _check_edges_and_coordinates(self, idx_s_last, isc_back, err_help=("self", "backward")) + _check_edges_and_coordinates( + self, idx_s_last, isc_back, err_help=("self", "backward") + ) # [self|other]: other sc-connections backward must be on right-aligned matching atoms - _check_edges_and_coordinates(other, idx_o_last, isc_back, err_help=("other", "backward")) + _check_edges_and_coordinates( + other, idx_o_last, isc_back, err_help=("other", "backward") + ) # Now we have ensured that the overlapping coordinates and the connectivity graph # co-incide and that we can actually perform the merge. @@ -2646,14 +2749,20 @@ def _sc_index_sort(isc): # First scale all values idx_s_first = self.geometry.a2o(idx_s_first, all=True).reshape(1, -1) idx_s_last = self.geometry.a2o(idx_s_last, all=True).reshape(1, -1) - col = concatenate(((idx_s_first + idx_iscP).ravel(), - (idx_s_last + idx_iscM).ravel())) + col = concatenate( + ((idx_s_first + idx_iscP).ravel(), (idx_s_last + idx_iscM).ravel()) + ) full._csr.scale_columns(col, scale[0]) - idx_o_first = other.geometry.a2o(idx_o_first, all=True).reshape(1, -1) + self.geometry.no - idx_o_last = other.geometry.a2o(idx_o_last, all=True).reshape(1, -1) + self.geometry.no - col = concatenate(((idx_o_first + idx_iscP).ravel(), - (idx_o_last + idx_iscM).ravel())) + idx_o_first = ( + other.geometry.a2o(idx_o_first, all=True).reshape(1, -1) + self.geometry.no + ) + idx_o_last = ( + other.geometry.a2o(idx_o_last, all=True).reshape(1, -1) + self.geometry.no + ) + col = concatenate( + ((idx_o_first + idx_iscP).ravel(), (idx_o_last + idx_iscM).ravel()) + ) full._csr.scale_columns(col, scale[1]) # Clean up (they may be very large) @@ -2665,20 +2774,28 @@ def _sc_index_sort(isc): # self[0] -> self[1] changes to self[0] -> full_G[0] | other[0] # self[0] -> self[-1] changes to self[0] -> full_G[-1] | other[-1] # other[0] -> other[-1] changes to other[0] -> full_G[0] | self[0] - col_from = concatenate(((idx_o_first + idx_iscP).ravel(), - (idx_s_first + idx_iscP).ravel(), - (idx_s_last + idx_iscM).ravel(), - (idx_o_last + idx_iscM).ravel())) - col_to = concatenate(((idx_s_first + idx_iscP).ravel(), - (idx_o_first + idx_isc0).ravel(), - (idx_o_last + idx_iscM).ravel(), - (idx_s_last + idx_isc0).ravel())) + col_from = concatenate( + ( + (idx_o_first + idx_iscP).ravel(), + (idx_s_first + idx_iscP).ravel(), + (idx_s_last + idx_iscM).ravel(), + (idx_o_last + idx_iscM).ravel(), + ) + ) + col_to = concatenate( + ( + (idx_s_first + idx_iscP).ravel(), + (idx_o_first + idx_isc0).ravel(), + (idx_o_last + idx_iscM).ravel(), + (idx_s_last + idx_isc0).ravel(), + ) + ) full._csr.translate_columns(col_from, col_to) return full - def replace(self, atoms, other, other_atoms=None, eps=0.005, scale=1.): - r""" Replace `atoms` in `self` with `other_atoms` in `other` and retain couplings between them + def replace(self, atoms, other, other_atoms=None, eps=0.005, scale=1.0): + r"""Replace `atoms` in `self` with `other_atoms` in `other` and retain couplings between them This method replaces a subset of atoms in `self` with another sparse geometry retaining any couplings between them. @@ -2799,7 +2916,7 @@ def replace(self, atoms, other, other_atoms=None, eps=0.005, scale=1.): # figure out the atoms that needs replacement def get_reduced_system(sp, atoms): - """ convert the geometry in `sp` to only atoms `atoms` and return the following: + """convert the geometry in `sp` to only atoms `atoms` and return the following: 1. atoms (sanitized and no order change) 2. orbitals (ordered as `atoms` @@ -2809,9 +2926,11 @@ def get_reduced_system(sp, atoms): geom = sp.geometry atoms = _a.asarrayi(geom._sanitize_atoms(atoms)).ravel() if unique(atoms).size != atoms.size: - raise ValueError(f"{self.__class__.__name__}.replace requires a unique set of atoms") + raise ValueError( + f"{self.__class__.__name__}.replace requires a unique set of atoms" + ) orbs = geom.a2o(atoms, all=True) - #other_orbs = geom.ouc2sc(np.delete(_a.arangei(geom.no), orbs)) + # other_orbs = geom.ouc2sc(np.delete(_a.arangei(geom.no), orbs)) # Find the orbitals that these atoms connect to such that we can compare # atomic coordinates @@ -2821,7 +2940,9 @@ def get_reduced_system(sp, atoms): out_connect_atom = geom.asc2uc(out_connect_atom_sc, True) # figure out connecting back - atoms_orbs = list(map(_a.arangei, geom.firsto[atoms], geom.firsto[atoms+1])) + atoms_orbs = list( + map(_a.arangei, geom.firsto[atoms], geom.firsto[atoms + 1]) + ) in_connect_atom = [] in_connect_orb = [] @@ -2850,11 +2971,11 @@ def get_reduced_system(sp, atoms): sgeom = self.geometry s_info = get_reduced_system(self, atoms) - atoms = s_info.atoms # sanitized (no order change) + atoms = s_info.atoms # sanitized (no order change) ogeom = other.geometry o_info = get_reduced_system(other, other_atoms) - other_atoms = o_info.atoms # sanitized (no order change) + other_atoms = o_info.atoms # sanitized (no order change) # Get overlapping atoms by their offset # We need to get a 1-1 correspondance between the two connecting geometries @@ -2864,7 +2985,7 @@ def get_reduced_system(sp, atoms): # connecting regions are within some given tolerance. def create_geometry(geom, atoms): - """ Create the supercell geometry with coordinates as given """ + """Create the supercell geometry with coordinates as given""" xyz = geom.axyz(atoms) uc_atoms = geom.sc2uc(atoms) return Geometry(xyz, atoms=geom.atoms[uc_atoms]) @@ -2874,17 +2995,23 @@ def create_geometry(geom, atoms): # Atoms *inside* the replacement region that couples out sgeom_in = sgeom.sub(s_info.atom_connect.uc.IN) ogeom_in = ogeom.sub(o_info.atom_connect.uc.IN) - soverlap_in, ooverlap_in = sgeom_in.overlap(ogeom_in, eps=eps, - offset=-sgeom_in.xyz.min(0), - offset_other=-ogeom_in.xyz.min(0)) + soverlap_in, ooverlap_in = sgeom_in.overlap( + ogeom_in, + eps=eps, + offset=-sgeom_in.xyz.min(0), + offset_other=-ogeom_in.xyz.min(0), + ) # Not replacement region, i.e. the IN (above) atoms are connecting to # these atoms: sgeom_out = create_geometry(sgeom, s_info.atom_connect.sc.OUT) ogeom_out = create_geometry(ogeom, o_info.atom_connect.sc.OUT) - soverlap_out, ooverlap_out = sgeom_out.overlap(ogeom_out, eps=eps, - offset=-sgeom_out.xyz.min(0), - offset_other=-ogeom_out.xyz.min(0)) + soverlap_out, ooverlap_out = sgeom_out.overlap( + ogeom_out, + eps=eps, + offset=-sgeom_out.xyz.min(0), + offset_other=-ogeom_out.xyz.min(0), + ) # trigger for errors err_msg = "" @@ -2894,19 +3021,21 @@ def create_geometry(geom, atoms): # Before proceeding we will check whether the dimensions match. # I.e. checking that the orbitals connecting in/out are the same is important. - #print("in:") - #print(s_info.atom_connect.uc.IN) - #print(soverlap_in) - #print(o_info.atom_connect.uc.IN) - #print(ooverlap_in) - if not (len(sgeom_in) == len(soverlap_in) and - len(ogeom_in) == len(ooverlap_in)): - + # print("in:") + # print(s_info.atom_connect.uc.IN) + # print(soverlap_in) + # print(o_info.atom_connect.uc.IN) + # print(ooverlap_in) + if not ( + len(sgeom_in) == len(soverlap_in) and len(ogeom_in) == len(ooverlap_in) + ): # figure out which atoms are not connecting - s_diff = np.setdiff1d(np.arange(s_info.atom_connect.uc.IN.size), - soverlap_in) - o_diff = np.setdiff1d(np.arange(o_info.atom_connect.uc.IN.size), - ooverlap_in) + s_diff = np.setdiff1d( + np.arange(s_info.atom_connect.uc.IN.size), soverlap_in + ) + o_diff = np.setdiff1d( + np.arange(o_info.atom_connect.uc.IN.size), ooverlap_in + ) if len(s_diff) > 0 or len(o_diff) > 0: err_msg = f"""{err_msg} @@ -2923,8 +3052,9 @@ def create_geometry(geom, atoms): other: atoms not matched in 'self': {o_info.atom_connect.uc.IN[o_diff]}.""" - elif not np.allclose(sgeom_in.orbitals[soverlap_in], - ogeom_in.orbitals[ooverlap_in]): + elif not np.allclose( + sgeom_in.orbitals[soverlap_in], ogeom_in.orbitals[ooverlap_in] + ): err_msg = f"""{err_msg} Atoms in the replacement region have different number of orbitals on the atoms @@ -2935,25 +3065,27 @@ def create_geometry(geom, atoms): other orbitals: {ogeom_in.orbitals[ooverlap_in]}""" - #print("out:") - #print(s_info.atom_connect.uc.OUT) - #print(soverlap_out) - #print(o_info.atom_connect.uc.OUT) - #print(ooverlap_out) + # print("out:") + # print(s_info.atom_connect.uc.OUT) + # print(soverlap_out) + # print(o_info.atom_connect.uc.OUT) + # print(ooverlap_out) # [so]overlap_out are now in the order of [so]_info.atom_connect.out # so we still have to convert them to proper indices if used # We cannot really check the soverlap_out == len(sgeom_out) # in case we have a replaced sparse matrix in the middle of another bigger # sparse matrix. - if not (len(sgeom_out) == len(soverlap_out) and - len(ogeom_out) == len(ooverlap_out)): - + if not ( + len(sgeom_out) == len(soverlap_out) and len(ogeom_out) == len(ooverlap_out) + ): # figure out which atoms are not connecting - s_diff = np.setdiff1d(np.arange(s_info.atom_connect.sc.OUT.size), - soverlap_out) - o_diff = np.setdiff1d(np.arange(o_info.atom_connect.sc.OUT.size), - ooverlap_out) + s_diff = np.setdiff1d( + np.arange(s_info.atom_connect.sc.OUT.size), soverlap_out + ) + o_diff = np.setdiff1d( + np.arange(o_info.atom_connect.sc.OUT.size), ooverlap_out + ) if len(s_diff) > 0 or len(o_diff) > 0: err_msg = f"""{err_msg} @@ -2969,8 +3101,9 @@ def create_geometry(geom, atoms): other: atoms (in supercell) connecting to 'other_atoms' not matched in 'self': {o_info.atom_connect.sc.OUT[o_diff]}.""" - elif not np.allclose(sgeom_out.orbitals[soverlap_out], - ogeom_out.orbitals[ooverlap_out]): + elif not np.allclose( + sgeom_out.orbitals[soverlap_out], ogeom_out.orbitals[ooverlap_out] + ): err_msg = f"""{err_msg} Atoms in the connection region have different number of orbitals on the atoms. @@ -2982,7 +3115,10 @@ def create_geometry(geom, atoms): # we can only ensure the orbitals that connect *out* have the same count # For supercell connections hopping *IN* might be different due to the supercell - if len(s_info.orb_connect.sc.OUT) != len(o_info.orb_connect.sc.OUT) and not err_msg: + if ( + len(s_info.orb_connect.sc.OUT) != len(o_info.orb_connect.sc.OUT) + and not err_msg + ): err_msg = f"""{err_msg} Number of orbitals connecting to replacement region is not consistent @@ -3001,8 +3137,10 @@ def create_geometry(geom, atoms): {sgeom_in.atoms[s_]} != {ogeom_in.atoms[o_]}""" if warn_msg: - warn(f"""Inequivalent atoms found in replacement region, this may or may not be a problem -depending on your use case. Please be careful though.{warn_msg}""") + warn( + f"""Inequivalent atoms found in replacement region, this may or may not be a problem +depending on your use case. Please be careful though.{warn_msg}""" + ) warn_msg = "" S_ = s_info.atom_connect.sc.OUT @@ -3017,8 +3155,10 @@ def create_geometry(geom, atoms): {sgeom_out.atoms[s_]} != {ogeom_out.atoms[o_]}""" if warn_msg: - warn(f"""Inequivalent atoms found in connection region, this may or may not be a problem -depending on your use case. Note indices in the following are supercell indices. Please be careful though.{warn_msg}""") + warn( + f"""Inequivalent atoms found in connection region, this may or may not be a problem +depending on your use case. Note indices in the following are supercell indices. Please be careful though.{warn_msg}""" + ) # clean-up to make it clear that we are not going to use them. del sgeom_out, ogeom_out @@ -3027,7 +3167,7 @@ def create_geometry(geom, atoms): ainsert_idx = atoms.min() oinsert_idx = sgeom.a2o(ainsert_idx) # this is the indices of the new atoms in the new geometry - #self_other_atoms = _a.arangei(ainsert_idx, ainsert_idx + len(other_atoms)) + # self_other_atoms = _a.arangei(ainsert_idx, ainsert_idx + len(other_atoms)) # We need to do the replacement in two steps # A. the geometry @@ -3072,8 +3212,10 @@ def a2o(geom, atoms, sc=True): ncol = insert(ncol, oinsert_idx, ocsr.ncol[o_info.orbitals]) # Create the sparse pattern - csr = SparseCSR((D, col, _ncol_to_indptr(ncol)), - shape=(geom.no, sgeom.no_s + ogeom.no_s, D.shape[1])) + csr = SparseCSR( + (D, col, _ncol_to_indptr(ncol)), + shape=(geom.no, sgeom.no_s + ogeom.no_s, D.shape[1]), + ) del D, col, ncol # Now we have merged the two sparse patterns @@ -3099,7 +3241,7 @@ def assert_unique(old, new): return old, new # 1: - #print("1:") + # print("1:") old = delete(_a.arangei(len(sgeom)), atoms) new = _a.arangei(len(old)) new[ainsert_idx:] += len(other_atoms) @@ -3110,7 +3252,7 @@ def assert_unique(old, new): rows = geom.osc2uc(new, unique=True) # 2: - #print("2:") + # print("2:") old = s_info.atom_connect.uc.IN[soverlap_in] # algorithm to get indices in other_atoms new = o_info.atom_connect.uc.IN[ooverlap_in] @@ -3130,27 +3272,29 @@ def assert_unique(old, new): convert = [[], []] # 3: - #print("3:") + # print("3:") # we have all the *inside* column indices offset by self.shape[1] old = a2o(ogeom, other_atoms, False) + self.shape[1] new = ainsert_idx + _a.arangei(len(other_atoms)) - #print("old: ", old) - #print("new: ", new) + # print("old: ", old) + # print("new: ", new) new = a2o(geom, new, False) convert[0].append(old) convert[1].append(new) rows = geom.osc2uc(new, unique=True) # 4: - #print("4:") + # print("4:") old = o_info.atom_connect.sc.OUT new = _a.emptyi(len(old)) for i, atom in enumerate(old): idx = geom.close(ogeom.axyz(atom) + offset, R=eps) - assert len(idx) == 1, f"More than 1 atom {idx} for atom {atom} = {ogeom.axyz(atom)}, {geom.axyz(idx)}" + assert ( + len(idx) == 1 + ), f"More than 1 atom {idx} for atom {atom} = {ogeom.axyz(atom)}, {geom.axyz(idx)}" new[i] = idx[0] - #print("old: ", old) - #print("new: ", new) + # print("old: ", old) + # print("new: ", new) old = a2o(ogeom, old, False) + self.shape[1] new = a2o(geom, new, False) @@ -3172,7 +3316,7 @@ def assert_unique(old, new): return out def toSparseAtom(self, dim=None, dtype=None): - """ Convert the sparse object (without data) to a new sparse object with equivalent but reduced sparse pattern + """Convert the sparse object (without data) to a new sparse object with equivalent but reduced sparse pattern This converts the orbital sparse pattern to an atomic sparse pattern. @@ -3200,11 +3344,10 @@ def toSparseAtom(self, dim=None, dtype=None): csr = self._csr # Now build the new sparse pattern - ptr = _a.emptyi(geom.na+1) + ptr = _a.emptyi(geom.na + 1) ptr[0] = 0 col = [None] * geom.na for ia in range(geom.na): - o1, o2 = geom.a2o([ia, ia + 1]) # Get current atomic elements idx = array_arange(csr.ptr[o1:o2], n=csr.ncol[o1:o2]) @@ -3215,7 +3358,7 @@ def toSparseAtom(self, dim=None, dtype=None): # Step counters col[ia] = acol - ptr[ia+1] = ptr[ia] + len(acol) + ptr[ia + 1] = ptr[ia] + len(acol) # Now we can create the sparse atomic col = np.concatenate(col, axis=0).astype(int32, copy=False) @@ -3225,5 +3368,5 @@ def toSparseAtom(self, dim=None, dtype=None): spAtom._csr.col = col spAtom._csr._D = np.zeros([len(col), dim], dtype=dtype) spAtom._csr._nnz = len(col) - spAtom._csr._finalized = True # unique returns sorted elements + spAtom._csr._finalized = True # unique returns sorted elements return spAtom diff --git a/src/sisl/tests/test_atom.py b/src/sisl/tests/test_atom.py index cca67d020f..923efdbec2 100644 --- a/src/sisl/tests/test_atom.py +++ b/src/sisl/tests/test_atom.py @@ -13,23 +13,23 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): - self.C = Atom['C'] - self.C3 = Atom('C', [-1] * 3) - self.Au = Atom('Au') + self.C = Atom["C"] + self.C3 = Atom("C", [-1] * 3) + self.Au = Atom("Au") self.PT = PeriodicTable() return t() def test1(setup): - assert setup.C == Atom['C'] + assert setup.C == Atom["C"] assert setup.C == Atom[setup.C] - assert setup.C == Atom[Atom['C']] - assert setup.C == Atom(Atom['C']) + assert setup.C == Atom[Atom["C"]] + assert setup.C == Atom(Atom["C"]) assert setup.C == Atom[Atom(6)] - assert setup.Au == Atom['Au'] + assert setup.Au == Atom["Au"] assert setup.Au != setup.C assert setup.Au == setup.Au.copy() @@ -51,20 +51,20 @@ def test_atom_unknown(): def test2(setup): - C = Atom('C', R=20) + C = Atom("C", R=20) assert setup.C != C - Au = Atom('Au', R=20) + Au = Atom("Au", R=20) assert setup.C != Au - C = Atom['C'] + C = Atom["C"] assert setup.Au != C - Au = Atom['Au'] + Au = Atom["Au"] assert setup.C != Au def test3(setup): - assert setup.C.symbol == 'C' - assert setup.C.tag == 'C' - assert setup.Au.symbol == 'Au' + assert setup.C.symbol == "C" + assert setup.C.tag == "C" + assert setup.Au.symbol == "Au" def test4(setup): @@ -99,7 +99,7 @@ def test6(setup): def test7(setup): - assert Atom(1, [-1] * 3).radius() > 0. + assert Atom(1, [-1] * 3).radius() > 0.0 assert len(str(Atom(1, [-1] * 3))) @@ -111,44 +111,46 @@ def test8(setup): def test9(setup): - a = setup.PT.Z_label(['H', 2]) + a = setup.PT.Z_label(["H", 2]) assert len(a) == 2 - assert a[0] == 'H' - assert a[1] == 'He' + assert a[0] == "H" + assert a[1] == "He" a = setup.PT.Z_label(1) - assert a == 'H' + assert a == "H" def test10(setup): - assert setup.PT.atomic_mass(1) == setup.PT.atomic_mass('H') - assert np.allclose(setup.PT.atomic_mass([1, 2]), setup.PT.atomic_mass(['H', 'He'])) + assert setup.PT.atomic_mass(1) == setup.PT.atomic_mass("H") + assert np.allclose(setup.PT.atomic_mass([1, 2]), setup.PT.atomic_mass(["H", "He"])) def test11(setup): PT = setup.PT - for m in ['calc', 'empirical', 'vdw']: - assert PT.radius(1, method=m) == PT.radius('H', method=m) - assert np.allclose(PT.radius([1, 2], method=m), PT.radius(['H', 'He'], method=m)) + for m in ["calc", "empirical", "vdw"]: + assert PT.radius(1, method=m) == PT.radius("H", method=m) + assert np.allclose( + PT.radius([1, 2], method=m), PT.radius(["H", "He"], method=m) + ) def test_fail_equal(): - assert Atom(1.2) != 2. + assert Atom(1.2) != 2.0 def test_radius1(setup): with pytest.raises(AttributeError): - setup.PT.radius(1, method='unknown') + setup.PT.radius(1, method="unknown") def test_tag1(): - a = Atom(6, tag='my-tag') - assert a.tag == 'my-tag' + a = Atom(6, tag="my-tag") + assert a.tag == "my-tag" def test_negative1(): a = Atom(-1) - assert a.symbol == 'H' - assert a.tag == 'ghost' + assert a.symbol == "H" + assert a.tag == "ghost" assert a.Z == 1 @@ -170,7 +172,7 @@ def test_iter2(): def test_charge(): r = [1, 1, 2, 2] - a = Atom(5, [Orbital(1., 1.), Orbital(1., 1.), Orbital(2.), Orbital(2.)]) + a = Atom(5, [Orbital(1.0, 1.0), Orbital(1.0, 1.0), Orbital(2.0), Orbital(2.0)]) assert len(a.q0) == 4 assert a.q0.sum() == pytest.approx(2) @@ -209,8 +211,8 @@ def test_atoms_set(): def test_charge_diff(): - o1 = Orbital(1., 1.) - o2 = Orbital(1., .5) + o1 = Orbital(1.0, 1.0) + o2 = Orbital(1.0, 0.5) a1 = Atom(5, [o1, o2, o1, o2]) a2 = Atom(5, [o1, o2, o1, o2, o1, o1]) assert len(a1.q0) == 4 @@ -220,7 +222,7 @@ def test_charge_diff(): def test_multiple_orbitals(): - o = [Orbital(1., 1.), Orbital(2., .5), Orbital(3., .75)] + o = [Orbital(1.0, 1.0), Orbital(2.0, 0.5), Orbital(3.0, 0.75)] a1 = Atom(5, o) assert len(a1) == 3 for i in range(3): @@ -241,13 +243,13 @@ def test_multiple_orbitals(): def test_multiple_orbitals_fail_io(): - o = [Orbital(1., 1.), Orbital(2., .5), Orbital(3., .75)] + o = [Orbital(1.0, 1.0), Orbital(2.0, 0.5), Orbital(3.0, 0.75)] with pytest.raises(ValueError): Atom(5, o).sub([3]) def test_multiple_orbitals_fail_len(): - o = [Orbital(1., 1.), Orbital(2., .5), Orbital(3., .75)] + o = [Orbital(1.0, 1.0), Orbital(2.0, 0.5), Orbital(3.0, 0.75)] with pytest.raises(ValueError): Atom(5, o).sub([0, 0, 0, 0, 0]) @@ -255,14 +257,19 @@ def test_multiple_orbitals_fail_len(): def test_atom_getattr_orbs(): class UOrbital(Orbital): def __init__(self, *args, **kwargs): - U = kwargs.pop("U", 0.) + U = kwargs.pop("U", 0.0) super().__init__(*args, **kwargs) self._U = U + @property def U(self): return self._U - o = [UOrbital(1., 1., U=1.), UOrbital(2., .5, U=2.), UOrbital(3., .75, U=3.)] + o = [ + UOrbital(1.0, 1.0, U=1.0), + UOrbital(2.0, 0.5, U=2.0), + UOrbital(3.0, 0.75, U=3.0), + ] a = Atom(5, o) assert np.allclose(a.U, [1, 2, 3]) # also test the callable interface @@ -301,7 +308,7 @@ def test_atom_orbitals(): assert len(a) == 2 assert a[0].tag == "x" assert a[1].tag == "y" - assert a.maxR() < 0. + assert a.maxR() < 0.0 a = Atom(5, [1.2, 1.4]) assert len(a) == 2 assert a.maxR() == pytest.approx(1.4) @@ -309,6 +316,7 @@ def test_atom_orbitals(): def test_pickle(setup): import pickle as p + sC = p.dumps(setup.C) sC3 = p.dumps(setup.C3) sAu = p.dumps(setup.Au) diff --git a/src/sisl/tests/test_atoms.py b/src/sisl/tests/test_atoms.py index 446a3b647b..946db24122 100644 --- a/src/sisl/tests/test_atoms.py +++ b/src/sisl/tests/test_atoms.py @@ -13,11 +13,11 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): - self.C = Atom['C'] - self.C3 = Atom('C', [-1] * 3) - self.Au = Atom('Au') + self.C = Atom["C"] + self.C3 = Atom("C", [-1] * 3) + self.Au = Atom("Au") self.PT = PeriodicTable() return t() @@ -25,9 +25,9 @@ def __init__(self): def test_create1(setup): atom1 = Atoms([setup.C, setup.C3, setup.Au]) - atom2 = Atoms(['C', 'C', 'Au']) - atom3 = Atoms(['C', 6, 'Au']) - atom4 = Atoms(['Au', 6, 'C']) + atom2 = Atoms(["C", "C", "Au"]) + atom3 = Atoms(["C", 6, "Au"]) + atom4 = Atoms(["Au", 6, "C"]) assert atom2 == atom3 assert atom2 != atom4 assert atom2.hassame(atom4) @@ -53,180 +53,180 @@ def test_len(setup): def test_get1(): - atoms = Atoms(['C', 'C', 'Au']) - assert atoms[2] == Atom('Au') - assert atoms['Au'] == Atom('Au') - assert atoms[0] == Atom('C') - assert atoms[1] == Atom('C') - assert atoms['C'] == Atom('C') - assert atoms[0:2] == [Atom('C')]*2 - assert atoms[1:] == [Atom('C'), Atom('Au')] + atoms = Atoms(["C", "C", "Au"]) + assert atoms[2] == Atom("Au") + assert atoms["Au"] == Atom("Au") + assert atoms[0] == Atom("C") + assert atoms[1] == Atom("C") + assert atoms["C"] == Atom("C") + assert atoms[0:2] == [Atom("C")] * 2 + assert atoms[1:] == [Atom("C"), Atom("Au")] def test_set1(): # Add new atoms to the set - atom = Atoms(['C', 'C']) - assert atom[0] == Atom('C') - assert atom[1] == Atom('C') - atom[1] = Atom('Au') - assert atom[0] == Atom('C') - assert atom[1] == Atom('Au') - atom['C'] = Atom('Au') - assert atom[0] == Atom('Au') - assert atom[1] == Atom('Au') + atom = Atoms(["C", "C"]) + assert atom[0] == Atom("C") + assert atom[1] == Atom("C") + atom[1] = Atom("Au") + assert atom[0] == Atom("C") + assert atom[1] == Atom("Au") + atom["C"] = Atom("Au") + assert atom[0] == Atom("Au") + assert atom[1] == Atom("Au") @pytest.mark.filterwarnings("ignore", message="*Replacing atom") def test_set2(): # Add new atoms to the set - atom = Atoms(['C', 'C']) - assert atom[0] == Atom('C') - assert atom[1] == Atom('C') + atom = Atoms(["C", "C"]) + assert atom[0] == Atom("C") + assert atom[1] == Atom("C") assert len(atom.atom) == 1 - atom[1] = Atom('Au', [-1] * 2) - assert atom[0] == Atom('C') - assert atom[1] != Atom('Au') - assert atom[1] == Atom('Au', [-1] * 2) + atom[1] = Atom("Au", [-1] * 2) + assert atom[0] == Atom("C") + assert atom[1] != Atom("Au") + assert atom[1] == Atom("Au", [-1] * 2) assert len(atom.atom) == 2 - atom['C'] = Atom('Au', [-1] * 2) - assert atom[0] != Atom('Au') - assert atom[0] == Atom('Au', [-1] * 2) - assert atom[1] != Atom('Au') - assert atom[1] == Atom('Au', [-1] * 2) + atom["C"] = Atom("Au", [-1] * 2) + assert atom[0] != Atom("Au") + assert atom[0] == Atom("Au", [-1] * 2) + assert atom[1] != Atom("Au") + assert atom[1] == Atom("Au", [-1] * 2) assert len(atom.atom) == 1 def test_set3(): # Add new atoms to the set - atom = Atoms(['C'] * 10) - atom[range(1, 4)] = Atom('Au', [-1] * 2) - assert atom[0] == Atom('C') + atom = Atoms(["C"] * 10) + atom[range(1, 4)] = Atom("Au", [-1] * 2) + assert atom[0] == Atom("C") for i in range(1, 4): - assert atom[i] != Atom('Au') - assert atom[i] == Atom('Au', [-1] * 2) - assert atom[4] != Atom('Au') - assert atom[4] != Atom('Au', [-1] * 2) + assert atom[i] != Atom("Au") + assert atom[i] == Atom("Au", [-1] * 2) + assert atom[4] != Atom("Au") + assert atom[4] != Atom("Au", [-1] * 2) assert len(atom.atom) == 2 - atom[1:4] = Atom('C') + atom[1:4] = Atom("C") assert len(atom.atom) == 2 @pytest.mark.filterwarnings("ignore", message="*Replacing atom") def test_replace1(): # Add new atoms to the set - atom = Atoms(['C'] * 10 + ['B'] * 2) - atom[range(1, 4)] = Atom('Au', [-1] * 2) - assert atom[0] == Atom('C') + atom = Atoms(["C"] * 10 + ["B"] * 2) + atom[range(1, 4)] = Atom("Au", [-1] * 2) + assert atom[0] == Atom("C") for i in range(1, 4): - assert atom[i] != Atom('Au') - assert atom[i] == Atom('Au', [-1] * 2) - assert atom[4] != Atom('Au') - assert atom[4] != Atom('Au', [-1] * 2) + assert atom[i] != Atom("Au") + assert atom[i] == Atom("Au", [-1] * 2) + assert atom[4] != Atom("Au") + assert atom[4] != Atom("Au", [-1] * 2) assert len(atom.atom) == 3 - atom.replace(atom[0], Atom('C', [-1] * 2)) - assert atom[0] == Atom('C', [-1] * 2) + atom.replace(atom[0], Atom("C", [-1] * 2)) + assert atom[0] == Atom("C", [-1] * 2) assert len(atom.atom) == 3 - assert atom[0] == Atom('C', [-1] * 2) + assert atom[0] == Atom("C", [-1] * 2) for i in range(4, 10): - assert atom[i] == Atom('C', [-1] * 2) + assert atom[i] == Atom("C", [-1] * 2) for i in range(1, 4): - assert atom[i] == Atom('Au', [-1] * 2) + assert atom[i] == Atom("Au", [-1] * 2) for i in range(10, 12): - assert atom[i] == Atom('B') + assert atom[i] == Atom("B") @pytest.mark.filterwarnings("ignore", message="*Substituting atom") @pytest.mark.filterwarnings("ignore", message="*Replacing atom") def test_replace2(): # Add new atoms to the set - atom = Atoms(['C'] * 10 + ['B'] * 2) - atom.replace(range(1, 4), Atom('Au', [-1] * 2)) - assert atom[0] == Atom('C') + atom = Atoms(["C"] * 10 + ["B"] * 2) + atom.replace(range(1, 4), Atom("Au", [-1] * 2)) + assert atom[0] == Atom("C") for i in range(1, 4): - assert atom[i] != Atom('Au') - assert atom[i] == Atom('Au', [-1] * 2) - assert atom[4] != Atom('Au') - assert atom[4] != Atom('Au', [-1] * 2) + assert atom[i] != Atom("Au") + assert atom[i] == Atom("Au", [-1] * 2) + assert atom[4] != Atom("Au") + assert atom[4] != Atom("Au", [-1] * 2) assert len(atom.atom) == 3 # Second replace call (equivalent to replace_atom) - atom.replace(atom[0], Atom('C', [-1] * 2)) - assert atom[0] == Atom('C', [-1] * 2) + atom.replace(atom[0], Atom("C", [-1] * 2)) + assert atom[0] == Atom("C", [-1] * 2) assert len(atom.atom) == 3 - assert atom[0] == Atom('C', [-1] * 2) + assert atom[0] == Atom("C", [-1] * 2) for i in range(4, 10): - assert atom[i] == Atom('C', [-1] * 2) + assert atom[i] == Atom("C", [-1] * 2) for i in range(1, 4): - assert atom[i] == Atom('Au', [-1] * 2) + assert atom[i] == Atom("Au", [-1] * 2) for i in range(10, 12): - assert atom[i] == Atom('B') + assert atom[i] == Atom("B") def test_append1(): # Add new atoms to the set - atom1 = Atoms(['C', 'C']) - assert atom1[0] == Atom('C') - assert atom1[1] == Atom('C') - atom2 = Atoms([Atom('C', tag='DZ'), Atom[6]]) - assert atom2[0] == Atom('C', tag='DZ') - assert atom2[1] == Atom('C') + atom1 = Atoms(["C", "C"]) + assert atom1[0] == Atom("C") + assert atom1[1] == Atom("C") + atom2 = Atoms([Atom("C", tag="DZ"), Atom[6]]) + assert atom2[0] == Atom("C", tag="DZ") + assert atom2[1] == Atom("C") atom = atom1.append(atom2) - assert atom[0] == Atom('C') - assert atom[1] == Atom('C') - assert atom[2] == Atom('C', tag='DZ') - assert atom[3] == Atom('C') + assert atom[0] == Atom("C") + assert atom[1] == Atom("C") + assert atom[2] == Atom("C", tag="DZ") + assert atom[3] == Atom("C") - atom = atom1.append(Atom(6, tag='DZ')) - assert atom[0] == Atom('C') - assert atom[1] == Atom('C') - assert atom[2] == Atom('C', tag='DZ') + atom = atom1.append(Atom(6, tag="DZ")) + assert atom[0] == Atom("C") + assert atom[1] == Atom("C") + assert atom[2] == Atom("C", tag="DZ") - atom = atom1.append([Atom(6, tag='DZ'), Atom[6]]) - assert atom[0] == Atom('C') - assert atom[1] == Atom('C') - assert atom[2] == Atom('C', tag='DZ') - assert atom[3] == Atom('C') + atom = atom1.append([Atom(6, tag="DZ"), Atom[6]]) + assert atom[0] == Atom("C") + assert atom[1] == Atom("C") + assert atom[2] == Atom("C", tag="DZ") + assert atom[3] == Atom("C") def test_compare1(): # Add new atoms to the set - atom1 = Atoms([Atom('C', tag='DZ'), Atom[6]]) - atom2 = Atoms([Atom[6], Atom('C', tag='DZ')]) + atom1 = Atoms([Atom("C", tag="DZ"), Atom[6]]) + atom2 = Atoms([Atom[6], Atom("C", tag="DZ")]) assert atom1.hassame(atom2) assert not atom1.equal(atom2) def test_in1(): # Add new atoms to the set - atom = Atoms(['C', 'C']) + atom = Atoms(["C", "C"]) assert Atom[6] in atom assert Atom[1] not in atom def test_iter1(): # Add new atoms to the set - atom = Atoms(['C', 'C']) + atom = Atoms(["C", "C"]) for a in atom.iter(): assert a == Atom[6] for a, idx in atom.iter(True): assert a == Atom[6] assert len(idx) == 2 - atom = Atoms(['C', 'Au', 'C', 'Au']) + atom = Atoms(["C", "Au", "C", "Au"]) for i, aidx in enumerate(atom.iter(True)): a, idx = aidx if i == 0: assert a == Atom[6] assert (idx == [0, 2]).all() elif i == 1: - assert a == Atom['Au'] + assert a == Atom["Au"] assert (idx == [1, 3]).all() assert len(idx) == 2 def test_reduce1(): - atom = Atoms(['C', 'Au']) + atom = Atoms(["C", "Au"]) atom = atom.sub(0) atom1 = atom.reduce() assert atom[0] == Atom[6] @@ -243,7 +243,7 @@ def test_reduce1(): def test_remove1(): - atom = Atoms(['C', 'Au']) + atom = Atoms(["C", "Au"]) atom = atom.remove(1) atom = atom.reduce() assert atom[0] == Atom[6] @@ -252,41 +252,41 @@ def test_remove1(): def test_reorder1(): - atom = Atoms(['C', 'Au']) + atom = Atoms(["C", "Au"]) atom = atom.sub(1) atom1 = atom.reorder() # Check we haven't done anything to the original Atoms object - assert atom[0] == Atom['Au'] + assert atom[0] == Atom["Au"] assert atom.specie[0] == 1 assert len(atom) == 1 assert len(atom.atom) == 2 - assert atom1[0] == Atom['Au'] + assert atom1[0] == Atom["Au"] assert atom1.specie[0] == 0 assert len(atom1) == 1 assert len(atom1.atom) == 2 # Do in-place atom.reorder(True) - assert atom[0] == Atom['Au'] + assert atom[0] == Atom["Au"] assert atom.specie[0] == 0 assert len(atom) == 1 assert len(atom.atom) == 2 def test_reorder2(): - atom1 = Atoms(['C', 'Au']) + atom1 = Atoms(["C", "Au"]) atom2 = atom1.reorder() assert atom1 == atom2 def test_charge1(): - atom = Atoms(['C', 'Au']) + atom = Atoms(["C", "Au"]) assert len(atom.q0) == 2 - assert atom.q0.sum() == pytest.approx(0.) + assert atom.q0.sum() == pytest.approx(0.0) def test_charge_diff(): - o1 = Orbital(1., 1.) - o2 = Orbital(1., .5) + o1 = Orbital(1.0, 1.0) + o2 = Orbital(1.0, 0.5) a1 = Atom(5, [o1, o2, o1, o2]) a2 = Atom(5, [o1, o2, o1, o2, o1, o1]) a = Atoms([a1, a2]) @@ -296,6 +296,6 @@ def test_charge_diff(): def test_index1(): - atom = Atoms(['C', 'Au']) + atom = Atoms(["C", "Au"]) with pytest.raises(KeyError): - atom.index(Atom('B')) + atom.index(Atom("B")) diff --git a/src/sisl/tests/test_geometry.py b/src/sisl/tests/test_geometry.py index 37f5c655c8..c52a24d3ef 100644 --- a/src/sisl/tests/test_geometry.py +++ b/src/sisl/tests/test_geometry.py @@ -20,31 +20,37 @@ Sphere, ) -_dir = osp.join('sisl') +_dir = osp.join("sisl") pytestmark = [pytest.mark.geom, pytest.mark.geometry] @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) - C = Atom(Z=6, R=[bond * 1.01]*2) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) + C = Atom(Z=6, R=[bond * 1.01] * 2) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.mol = Geometry([[i, 0, 0] for i in range(10)], lattice=[50]) + return t() class TestGeometry: - def test_objects(self, setup): str(setup.g) assert len(setup.g) == 2 @@ -61,10 +67,10 @@ def test_objects(self, setup): def test_properties(self, setup): assert 2 == len(setup.g) assert 2 == setup.g.na - assert 3*3 == setup.g.n_s - assert 2*3*3 == setup.g.na_s - assert 2*2 == setup.g.no - assert 2*2*3*3 == setup.g.no_s + assert 3 * 3 == setup.g.n_s + assert 2 * 3 * 3 == setup.g.na_s + assert 2 * 2 == setup.g.no + assert 2 * 2 * 3 * 3 == setup.g.no_s def test_iter1(self, setup): i = 0 @@ -138,19 +144,19 @@ def test_tile3(self, setup): cell[0, :] *= 2 t1 = setup.g * (2, 0) assert np.allclose(cell, t1.lattice.cell) - t = setup.g * ((2, 0), 'tile') + t = setup.g * ((2, 0), "tile") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) cell[1, :] *= 2 t1 = t * (2, 1) assert np.allclose(cell, t1.lattice.cell) - t = t * ((2, 1), 'tile') + t = t * ((2, 1), "tile") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) cell[2, :] *= 2 t1 = t * (2, 2) assert np.allclose(cell, t1.lattice.cell) - t = t * ((2, 2), 'tile') + t = t * ((2, 2), "tile") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) @@ -158,18 +164,18 @@ def test_tile3(self, setup): t = setup.g * [2, 2, 2] assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) - t = setup.g * ([2, 2, 2], 't') + t = setup.g * ([2, 2, 2], "t") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) def test_tile4(self, setup): t1 = setup.g.tile(2, 0).tile(2, 2) - t = setup.g * ([2, 0], 't') * [2, 2] + t = setup.g * ([2, 0], "t") * [2, 2] assert np.allclose(t1.xyz, t.xyz) def test_tile5(self, setup): t = setup.g.tile(2, 0).tile(2, 2) - assert np.allclose(t[:len(setup.g), :], setup.g.xyz) + assert np.allclose(t[: len(setup.g), :], setup.g.xyz) def test_repeat0(self, setup): with pytest.raises(ValueError): @@ -198,30 +204,30 @@ def test_repeat3(self, setup): cell[0, :] *= 2 t1 = setup.g.repeat(2, 0) assert np.allclose(cell, t1.lattice.cell) - t = setup.g * ((2, 0), 'repeat') + t = setup.g * ((2, 0), "repeat") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) cell[1, :] *= 2 t1 = t.repeat(2, 1) assert np.allclose(cell, t1.lattice.cell) - t = t * ((2, 1), 'r') + t = t * ((2, 1), "r") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) cell[2, :] *= 2 t1 = t.repeat(2, 2) assert np.allclose(cell, t1.lattice.cell) - t = t * ((2, 2), 'repeat') + t = t * ((2, 2), "repeat") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) # Full - t = setup.g * ([2, 2, 2], 'r') + t = setup.g * ([2, 2, 2], "r") assert np.allclose(cell, t.lattice.cell) assert np.allclose(t1.xyz, t.xyz) def test_repeat4(self, setup): t1 = setup.g.repeat(2, 0).repeat(2, 2) - t = setup.g * ([2, 0], 'repeat') * ([2, 2], 'r') + t = setup.g * ([2, 0], "repeat") * ([2, 2], "r") assert np.allclose(t1.xyz, t.xyz) def test_repeat5(self, setup): @@ -238,7 +244,9 @@ def test_sub1(self, setup): assert len(setup.g.sub([0, 1])) == 2 assert len(setup.g.sub([-1])) == 1 - assert np.allclose(setup.g.sub([0]).xyz, setup.g.sub(np.array([True, False])).xyz) + assert np.allclose( + setup.g.sub([0]).xyz, setup.g.sub(np.array([True, False])).xyz + ) def test_sub2(self, setup): assert len(setup.g.sub(range(1))) == 1 @@ -246,7 +254,7 @@ def test_sub2(self, setup): def test_fxyz(self, setup): fxyz = setup.g.fxyz - assert np.allclose(fxyz, [[0, 0, 0], [1./3, 1./3, 0]]) + assert np.allclose(fxyz, [[0, 0, 0], [1.0 / 3, 1.0 / 3, 0]]) assert np.allclose(np.dot(fxyz, setup.g.cell), setup.g.xyz) def test_axyz(self, setup): @@ -299,21 +307,21 @@ def test_ouc2sc(self, setup): def test_rij1(self, setup): assert np.allclose(setup.g.rij(0, 1), 1.42) - assert np.allclose(setup.g.rij(0, [0, 1]), [0., 1.42]) + assert np.allclose(setup.g.rij(0, [0, 1]), [0.0, 1.42]) def test_orij1(self, setup): assert np.allclose(setup.g.orij(0, 2), 1.42) - assert np.allclose(setup.g.orij(0, [0, 2]), [0., 1.42]) + assert np.allclose(setup.g.orij(0, [0, 2]), [0.0, 1.42]) def test_Rij1(self, setup): assert np.allclose(setup.g.Rij(0, 1), [1.42, 0, 0]) def test_oRij1(self, setup): - assert np.allclose(setup.g.oRij(0, 1), [0., 0, 0]) + assert np.allclose(setup.g.oRij(0, 1), [0.0, 0, 0]) assert np.allclose(setup.g.oRij(0, 2), [1.42, 0, 0]) - assert np.allclose(setup.g.oRij(0, [0, 1, 2]), [[0., 0, 0], - [0., 0, 0], - [1.42, 0, 0]]) + assert np.allclose( + setup.g.oRij(0, [0, 1, 2]), [[0.0, 0, 0], [0.0, 0, 0], [1.42, 0, 0]] + ) assert np.allclose(setup.g.oRij(0, 2), [1.42, 0, 0]) def test_untile_warns(self, setup): @@ -421,31 +429,31 @@ def test_rotation1(self, setup): assert np.allclose(rot.xyz, setup.g.xyz) def test_rotation2(self, setup): - rot = setup.g.rotate(180, "z", what='abc') + rot = setup.g.rotate(180, "z", what="abc") rot.lattice.cell[2, 2] *= -1 assert np.allclose(-rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(rot.xyz, setup.g.xyz) - rot = setup.g.rotate(np.pi, [0, 0, 1], rad=True, what='abc') + rot = setup.g.rotate(np.pi, [0, 0, 1], rad=True, what="abc") rot.lattice.cell[2, 2] *= -1 assert np.allclose(-rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(rot.xyz, setup.g.xyz) - rot = rot.rotate(180, [0, 0, 1], what='abc') + rot = rot.rotate(180, [0, 0, 1], what="abc") rot.lattice.cell[2, 2] *= -1 assert np.allclose(rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(rot.xyz, setup.g.xyz) def test_rotation3(self, setup): - rot = setup.g.rotate(180, [0, 0, 1], what='xyz') + rot = setup.g.rotate(180, [0, 0, 1], what="xyz") assert np.allclose(rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(-rot.xyz, setup.g.xyz) - rot = setup.g.rotate(np.pi, [0, 0, 1], rad=True, what='xyz') + rot = setup.g.rotate(np.pi, [0, 0, 1], rad=True, what="xyz") assert np.allclose(rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(-rot.xyz, setup.g.xyz) - rot = rot.rotate(180, "z", what='xyz') + rot = rot.rotate(180, "z", what="xyz") assert np.allclose(rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(rot.xyz, setup.g.xyz) @@ -460,12 +468,12 @@ def test_rotation4(self, setup): assert not np.allclose(ref.xyz[2], rot.xyz[2]) assert np.allclose(ref.xyz[3], rot.xyz[3]) - rot = ref.rotate(10, "z", atoms=[1, 2], what='y') + rot = ref.rotate(10, "z", atoms=[1, 2], what="y") assert ref.xyz[1, 0] == rot.xyz[1, 0] assert ref.xyz[1, 1] != rot.xyz[1, 1] assert ref.xyz[1, 2] == rot.xyz[1, 2] - rot = ref.rotate(10, "z", atoms=[1, 2], what='xy', origin=ref.xyz[2]) + rot = ref.rotate(10, "z", atoms=[1, 2], what="xy", origin=ref.xyz[2]) assert ref.xyz[1, 0] != rot.xyz[1, 0] assert ref.xyz[1, 1] != rot.xyz[1, 1] assert ref.xyz[1, 2] == rot.xyz[1, 2] @@ -513,11 +521,11 @@ def test_iter_block2(self, setup): def test_iter_shape1(self, setup): i = 0 - for ias, _ in setup.g.iter_block(method='sphere'): + for ias, _ in setup.g.iter_block(method="sphere"): i += len(ias) assert i == len(setup.g) i = 0 - for ias, _ in setup.g.iter_block(method='cube'): + for ias, _ in setup.g.iter_block(method="cube"): i += len(ias) assert i == len(setup.g) @@ -525,11 +533,11 @@ def test_iter_shape1(self, setup): def test_iter_shape2(self, setup): g = setup.g.tile(30, 0).tile(30, 1) i = 0 - for ias, _ in g.iter_block(method='sphere'): + for ias, _ in g.iter_block(method="sphere"): i += len(ias) assert i == len(g) i = 0 - for ias, _ in g.iter_block(method='cube'): + for ias, _ in g.iter_block(method="cube"): i += len(ias) assert i == len(g) i = 0 @@ -541,11 +549,11 @@ def test_iter_shape2(self, setup): def test_iter_shape3(self, setup): g = setup.g.tile(50, 0).tile(50, 1) i = 0 - for ias, _ in g.iter_block(method='sphere'): + for ias, _ in g.iter_block(method="sphere"): i += len(ias) assert i == len(g) i = 0 - for ias, _ in g.iter_block(method='cube'): + for ias, _ in g.iter_block(method="cube"): i += len(ias) assert i == len(g) i = 0 @@ -562,43 +570,43 @@ def test_append1(self, setup): for axis in [0, 1, 2]: s = setup.g.append(setup.g, axis) assert len(s) == len(setup.g) * 2 - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) s = setup.g.prepend(setup.g, axis) assert len(s) == len(setup.g) * 2 - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) s = setup.g.append(setup.g.lattice, axis) assert len(s) == len(setup.g) - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) s = setup.g.prepend(setup.g.lattice, axis) assert len(s) == len(setup.g) - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) - assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :]* 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) + assert np.allclose(s.cell[axis, :], setup.g.cell[axis, :] * 2) def test_append_raise_valueerror(self, setup): with pytest.raises(ValueError): - s = setup.g.append(setup.g, 0, offset='not') + s = setup.g.append(setup.g, 0, offset="not") def test_prepend_raise_valueerror(self, setup): with pytest.raises(ValueError): - s = setup.g.prepend(setup.g, 0, offset='not') + s = setup.g.prepend(setup.g, 0, offset="not") def test_append_prepend_offset(self, setup): for axis in [0, 1, 2]: t = setup.g.lattice.cell[axis, :].copy() - t *= 10. / (t ** 2).sum() ** 0.5 + t *= 10.0 / (t**2).sum() ** 0.5 s1 = setup.g.copy() s2 = setup.g.translate(t) - S = s1.append(s2, axis, offset='min') + S = s1.append(s2, axis, offset="min") s = setup.g.append(setup.g, axis) assert np.allclose(s.cell[axis, :], S.cell[axis, :]) assert np.allclose(s.xyz, S.xyz) - P = s2.prepend(s1, axis, offset='min') + P = s2.prepend(s1, axis, offset="min") p = setup.g.prepend(setup.g, axis) assert np.allclose(p.cell[axis, :], P.cell[axis, :]) @@ -639,12 +647,12 @@ def test_center(self, setup): assert np.allclose(g[1], g.center(atoms=[1])) assert np.allclose(np.mean(g.xyz, axis=0), g.center()) # in this case the pbc COM is equivalent to the simple one - assert np.allclose(g.center(what='mass'), g.center(what='mass:pbc')) - assert np.allclose(g.center(what='mm:xyz'), g.center(what='mm(xyz)')) + assert np.allclose(g.center(what="mass"), g.center(what="mass:pbc")) + assert np.allclose(g.center(what="mm:xyz"), g.center(what="mm(xyz)")) def test_center_raise(self, setup): with pytest.raises(ValueError): - al = setup.g.center(what='unknown') + al = setup.g.center(what="unknown") def test___add1__(self, setup): n = len(setup.g) @@ -686,12 +694,12 @@ def test___mul__(self, setup): assert g * (2, 2, 2) == g.tile(2, 0).tile(2, 1).tile(2, 2) assert g * [1, 2, 2] == g.tile(1, 0).tile(2, 1).tile(2, 2) assert g * [1, 3, 2] == g.tile(1, 0).tile(3, 1).tile(2, 2) - assert g * ([1, 3, 2], 'r') == g.repeat(1, 0).repeat(3, 1).repeat(2, 2) - assert g * ([1, 3, 2], 'repeat') == g.repeat(1, 0).repeat(3, 1).repeat(2, 2) - assert g * ([1, 3, 2], 'tile') == g.tile(1, 0).tile(3, 1).tile(2, 2) - assert g * ([1, 3, 2], 't') == g.tile(1, 0).tile(3, 1).tile(2, 2) - assert g * ([3, 2], 't') == g.tile(3, 2) - assert g * ([3, 2], 'r') == g.repeat(3, 2) + assert g * ([1, 3, 2], "r") == g.repeat(1, 0).repeat(3, 1).repeat(2, 2) + assert g * ([1, 3, 2], "repeat") == g.repeat(1, 0).repeat(3, 1).repeat(2, 2) + assert g * ([1, 3, 2], "tile") == g.tile(1, 0).tile(3, 1).tile(2, 2) + assert g * ([1, 3, 2], "t") == g.tile(1, 0).tile(3, 1).tile(2, 2) + assert g * ([3, 2], "t") == g.tile(3, 2) + assert g * ([3, 2], "r") == g.repeat(3, 2) def test_add(self, setup): double = setup.g.add(setup.g) @@ -743,20 +751,16 @@ def test_2sc(self, setup): c = setup.g.cell # check indices - assert np.all(setup.g.a2isc([1, 2]) == [[0, 0, 0], - [-1, -1, 0]]) + assert np.all(setup.g.a2isc([1, 2]) == [[0, 0, 0], [-1, -1, 0]]) assert np.all(setup.g.a2isc(2) == [-1, -1, 0]) assert np.allclose(setup.g.a2sc(2), -c[0, :] - c[1, :]) - assert np.all(setup.g.o2isc([1, 5]) == [[0, 0, 0], - [-1, -1, 0]]) + assert np.all(setup.g.o2isc([1, 5]) == [[0, 0, 0], [-1, -1, 0]]) assert np.all(setup.g.o2isc(5) == [-1, -1, 0]) assert np.allclose(setup.g.o2sc(5), -c[0, :] - c[1, :]) # Check off-sets - assert np.allclose(setup.g.a2sc([1, 2]), [[0., 0., 0.], - -c[0, :] - c[1, :]]) - assert np.allclose(setup.g.o2sc([1, 5]), [[0., 0., 0.], - -c[0, :] - c[1, :]]) + assert np.allclose(setup.g.a2sc([1, 2]), [[0.0, 0.0, 0.0], -c[0, :] - c[1, :]]) + assert np.allclose(setup.g.o2sc([1, 5]), [[0.0, 0.0, 0.0], -c[0, :] - c[1, :]]) def test_reverse(self, setup): rev = setup.g.reverse() @@ -769,13 +773,13 @@ def test_reverse(self, setup): def test_scale1(self, setup): two = setup.g.scale(2) assert len(two) == len(setup.g) - assert np.allclose(two.xyz[:, :] / 2., setup.g.xyz) + assert np.allclose(two.xyz[:, :] / 2.0, setup.g.xyz) def test_scale_vector_abc(self, setup): two = setup.g.scale([2, 1, 1], what="abc") assert len(two) == len(setup.g) # Check that cell has been scaled accordingly - assert np.allclose(two.cell[0] / 2., setup.g.cell[0]) + assert np.allclose(two.cell[0] / 2.0, setup.g.cell[0]) assert np.allclose(two.cell[1:], setup.g.cell[1:]) # Now check that fractional coordinates are still the same assert np.allclose(two.fxyz, setup.g.fxyz) @@ -784,7 +788,7 @@ def test_scale_vector_xyz(self, setup): two = setup.g.scale([2, 1, 1], what="xyz") assert len(two) == len(setup.g) # Check that cell has been scaled accordingly - assert np.allclose(two.cell[:, 0] / 2., setup.g.cell[:, 0]) + assert np.allclose(two.cell[:, 0] / 2.0, setup.g.cell[:, 0]) assert np.allclose(two.cell[:, 1:], setup.g.cell[:, 1:]) # Now check that fractional coordinates are still the same assert np.allclose(two.fxyz, setup.g.fxyz) @@ -857,8 +861,7 @@ def test_close4(self, setup): def test_close_within1(self, setup): three = range(3) for ia in setup.mol: - shapes = [Sphere(0.1, setup.mol[ia]), - Sphere(1.1, setup.mol[ia])] + shapes = [Sphere(0.1, setup.mol[ia]), Sphere(1.1, setup.mol[ia])] i = setup.mol.close(ia, R=(0.1, 1.1), atoms=three) ii = setup.mol.within(shapes, atoms=three) assert np.all(i[0] == ii[0]) @@ -867,8 +870,7 @@ def test_close_within1(self, setup): def test_close_within2(self, setup): g = setup.g.repeat(6, 0).repeat(6, 1) for ia in g: - shapes = [Sphere(0.1, g[ia]), - Sphere(1.5, g[ia])] + shapes = [Sphere(0.1, g[ia]), Sphere(1.5, g[ia])] i = g.close(ia, R=(0.1, 1.5)) ii = g.within(shapes) assert np.all(i[0] == ii[0]) @@ -876,10 +878,9 @@ def test_close_within2(self, setup): def test_close_within3(self, setup): g = setup.g.repeat(6, 0).repeat(6, 1) - args = {'ret_xyz': True, 'ret_rij': True, 'ret_isc': True} + args = {"ret_xyz": True, "ret_rij": True, "ret_isc": True} for ia in g: - shapes = [Sphere(0.1, g[ia]), - Sphere(1.5, g[ia])] + shapes = [Sphere(0.1, g[ia]), Sphere(1.5, g[ia])] i, xa, d, isc = g.close(ia, R=(0.1, 1.5), **args) ii, xai, di, isci = g.within(shapes, **args) for j in [0, 1]: @@ -891,14 +892,14 @@ def test_close_within3(self, setup): def test_within_inf1(self, setup): g = setup.g.translate([0.05] * 3) lattice_3x3 = g.lattice.tile(3, 0).tile(3, 1) - assert len(g.within_inf(lattice_3x3)[0]) == len(g) * 3 ** 2 + assert len(g.within_inf(lattice_3x3)[0]) == len(g) * 3**2 def test_within_inf_nonperiodic(self, setup): g = setup.g.copy() # Even if the geometry has nsc > 1, if we set periodic=False # we should get only the atoms in the unit cell. - g.set_nsc([3,3,1]) + g.set_nsc([3, 3, 1]) ia, xyz, isc = g.within_inf(g.lattice, periodic=[False, False, False]) @@ -925,34 +926,36 @@ def test_within_inf2(self, setup): def test_within_inf_duplicates(self, setup): g = setup.g.copy() lattice_3x3 = g.lattice.tile(3, 0).tile(3, 1) - assert len(g.within_inf(lattice_3x3)[0]) == len(g) * 3 ** 2 + 7 # 3 per vector and 1 in the upper right corner + assert ( + len(g.within_inf(lattice_3x3)[0]) == len(g) * 3**2 + 7 + ) # 3 per vector and 1 in the upper right corner def test_close_sizes(self, setup): point = 0 # Return index - idx = setup.mol.close(point, R=.1) + idx = setup.mol.close(point, R=0.1) assert len(idx) == 1 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1)) + idx = setup.mol.close(point, R=(0.1, 1.1)) assert len(idx) == 2 assert len(idx[0]) == 1 assert not isinstance(idx[0], list) # Longer - idx = setup.mol.close(point, R=(.1, 1.1, 2.1)) + idx = setup.mol.close(point, R=(0.1, 1.1, 2.1)) assert len(idx) == 3 assert len(idx[0]) == 1 # Return index - idx = setup.mol.close(point, R=.1, ret_xyz=True) + idx = setup.mol.close(point, R=0.1, ret_xyz=True) assert len(idx) == 2 assert len(idx[0]) == 1 assert len(idx[1]) == 1 - assert idx[1].shape[0] == 1 # equivalent to above + assert idx[1].shape[0] == 1 # equivalent to above assert idx[1].shape[1] == 3 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1), ret_xyz=True) + idx = setup.mol.close(point, R=(0.1, 1.1), ret_xyz=True) # [[idx-1, idx-2], [coord-1, coord-2]] assert len(idx) == 2 assert len(idx[0]) == 2 @@ -969,7 +972,9 @@ def test_close_sizes(self, setup): assert idx[1][1].shape[1] == 3 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1), ret_xyz=True, ret_rij=True, ret_isc=True) + idx = setup.mol.close( + point, R=(0.1, 1.1), ret_xyz=True, ret_rij=True, ret_isc=True + ) # [[idx-1, idx-2], [coord-1, coord-2], [dist-1, dist-2], [isc-1, isc-2]] assert len(idx) == 4 assert len(idx[0]) == 2 @@ -1000,7 +1005,7 @@ def test_close_sizes(self, setup): assert idx[3][1].shape[1] == 3 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1), ret_rij=True) + idx = setup.mol.close(point, R=(0.1, 1.1), ret_rij=True) # [[idx-1, idx-2], [dist-1, dist-2]] assert len(idx) == 2 assert len(idx[0]) == 2 @@ -1017,31 +1022,31 @@ def test_close_sizes(self, setup): assert idx[1][1].shape[0] == 1 def test_close_sizes_none(self, setup): - point = [100., 100., 100.] + point = [100.0, 100.0, 100.0] # Return index - idx = setup.mol.close(point, R=.1) + idx = setup.mol.close(point, R=0.1) assert len(idx) == 0 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1)) + idx = setup.mol.close(point, R=(0.1, 1.1)) assert len(idx) == 2 assert len(idx[0]) == 0 assert not isinstance(idx[0], list) # Longer - idx = setup.mol.close(point, R=(.1, 1.1, 2.1)) + idx = setup.mol.close(point, R=(0.1, 1.1, 2.1)) assert len(idx) == 3 assert len(idx[0]) == 0 # Return index - idx = setup.mol.close(point, R=.1, ret_xyz=True) + idx = setup.mol.close(point, R=0.1, ret_xyz=True) assert len(idx) == 2 assert len(idx[0]) == 0 assert len(idx[1]) == 0 - assert idx[1].shape[0] == 0 # equivalent to above + assert idx[1].shape[0] == 0 # equivalent to above assert idx[1].shape[1] == 3 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1), ret_xyz=True) + idx = setup.mol.close(point, R=(0.1, 1.1), ret_xyz=True) # [[idx-1, idx-2], [coord-1, coord-2]] assert len(idx) == 2 assert len(idx[0]) == 2 @@ -1058,7 +1063,7 @@ def test_close_sizes_none(self, setup): assert idx[1][1].shape[1] == 3 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1), ret_xyz=True, ret_rij=True) + idx = setup.mol.close(point, R=(0.1, 1.1), ret_xyz=True, ret_rij=True) # [[idx-1, idx-2], [coord-1, coord-2], [dist-1, dist-2]] assert len(idx) == 3 assert len(idx[0]) == 2 @@ -1082,7 +1087,7 @@ def test_close_sizes_none(self, setup): assert idx[2][1].shape[0] == 0 # Return index of two things - idx = setup.mol.close(point, R=(.1, 1.1), ret_rij=True) + idx = setup.mol.close(point, R=(0.1, 1.1), ret_rij=True) # [[idx-1, idx-2], [dist-1, dist-2]] assert len(idx) == 2 assert len(idx[0]) == 2 @@ -1108,11 +1113,11 @@ def test_bond_correct(self, setup): rib.atoms[-1] = Atom[1] ia = len(rib) - 1 # Get bond-length - idx, d = rib.close(ia, R=(.1, 1000), ret_rij=True) + idx, d = rib.close(ia, R=(0.1, 1000), ret_rij=True) i = np.argmin(d[1]) d = d[1][i] rib.bond_correct(ia, idx[1][i]) - idx, d2 = rib.close(ia, R=(.1, 1000), ret_rij=True) + idx, d2 = rib.close(ia, R=(0.1, 1000), ret_rij=True) i = np.argmin(d2[1]) d2 = d2[1][i] assert d != d2 @@ -1138,7 +1143,7 @@ def test_unit_cell_estimation2(self, setup): # Assert that it correctly calculates the bond-length in the # directions of actual distance - g1 = Geometry([[0, 0, 0], [1, 1, 0]], atoms='H', lattice=s1) + g1 = Geometry([[0, 0, 0], [1, 1, 0]], atoms="H", lattice=s1) g2 = Geometry(np.copy(g1.xyz)) for i in range(2): assert np.allclose(g1.cell[i, :], g2.cell[i, :]) @@ -1153,7 +1158,7 @@ def test_distance1(self, setup): def test_distance2(self, setup): geom = Geometry(setup.g.xyz, Atom[6]) with pytest.raises(ValueError): - d = geom.distance(R=1.42, method='unknown_numpy_function') + d = geom.distance(R=1.42, method="unknown_numpy_function") def test_distance3(self, setup): geom = setup.g.copy() @@ -1169,7 +1174,7 @@ def test_distance4(self, setup): d = geom.distance(method=np.max) assert len(d) == 1 assert np.allclose(d, [1.42]) - d = geom.distance(method='max') + d = geom.distance(method="max") assert len(d) == 1 assert np.allclose(d, [1.42]) @@ -1183,16 +1188,16 @@ def test_distance5(self, setup): def test_distance6(self, setup): # Create a 1D chain - geom = Geometry([0]*3, Atom(1, R=1.), lattice=1) + geom = Geometry([0] * 3, Atom(1, R=1.0), lattice=1) geom.set_nsc([77, 1, 1]) d = geom.distance(0) assert len(d) == 1 - assert np.allclose(d, [1.]) + assert np.allclose(d, [1.0]) # Do twice d = geom.distance(R=2) assert len(d) == 2 - assert np.allclose(d, [1., 2.]) + assert np.allclose(d, [1.0, 2.0]) # Do all d = geom.distance(R=np.inf) @@ -1202,41 +1207,41 @@ def test_distance6(self, setup): # Create a 2D grid geom.set_nsc([3, 3, 1]) - d = geom.distance(R=2, tol=[.4, .3, .2, .1]) - assert len(d) == 2 # 1, sqrt(2) + d = geom.distance(R=2, tol=[0.4, 0.3, 0.2, 0.1]) + assert len(d) == 2 # 1, sqrt(2) # Add one due arange not adding the last item - assert np.allclose(d, [1, 2 ** .5]) + assert np.allclose(d, [1, 2**0.5]) # Create a 2D grid geom.set_nsc([5, 5, 1]) - d = geom.distance(R=2, tol=[.4, .3, .2, .1]) - assert len(d) == 3 # 1, sqrt(2), 2 + d = geom.distance(R=2, tol=[0.4, 0.3, 0.2, 0.1]) + assert len(d) == 3 # 1, sqrt(2), 2 # Add one due arange not adding the last item - assert np.allclose(d, [1, 2 ** .5, 2]) + assert np.allclose(d, [1, 2**0.5, 2]) def test_distance7(self, setup): # Create a 1D chain - geom = Geometry([0]*3, Atom(1, R=1.), lattice=1) + geom = Geometry([0] * 3, Atom(1, R=1.0), lattice=1) geom.set_nsc([77, 1, 1]) # Try with a short R and a long tolerance list # We know that the tolerance list prevails, because - d = geom.distance(R=1, tol=np.ones(10) * .5) + d = geom.distance(R=1, tol=np.ones(10) * 0.5) assert len(d) == 1 - assert np.allclose(d, [1.]) + assert np.allclose(d, [1.0]) def test_distance8(self, setup): - geom = Geometry([0]*3, Atom(1, R=1.), lattice=1) + geom = Geometry([0] * 3, Atom(1, R=1.0), lattice=1) geom.set_nsc([77, 1, 1]) - d = geom.distance(0, method='min') + d = geom.distance(0, method="min") assert len(d) == 1 - d = geom.distance(0, method='median') + d = geom.distance(0, method="median") assert len(d) == 1 - d = geom.distance(0, method='mode') + d = geom.distance(0, method="mode") assert len(d) == 1 def test_optimize_nsc1(self, setup): # Create a 1D chain - geom = Geometry([0]*3, Atom(1, R=1.), lattice=1) + geom = Geometry([0] * 3, Atom(1, R=1.0), lattice=1) geom.set_nsc([77, 77, 77]) assert np.allclose(geom.optimize_nsc(), [3, 3, 3]) geom.set_nsc([77, 77, 77]) @@ -1252,7 +1257,7 @@ def test_optimize_nsc1(self, setup): def test_optimize_nsc2(self, setup): # 2 ** 0.5 ensures lattice vectors with length 1 - geom = sisl_geom.fcc(2 ** 0.5, Atom(1, R=1.0001)) + geom = sisl_geom.fcc(2**0.5, Atom(1, R=1.0001)) geom.set_nsc([77, 77, 77]) assert np.allclose(geom.optimize_nsc(), [3, 3, 3]) geom.set_nsc([77, 77, 77]) @@ -1274,39 +1279,84 @@ def test_argumentparser2(self, setup, **kwargs): p, ns = setup.g.ArgumentParser(**kwargs) # Try all options - opts = ['--origin', - '--center-of', 'mass', - '--center-of', 'xyz', - '--center-of', 'position', - '--center-of', 'cell', - '--unit-cell', 'translate', - '--unit-cell', 'mod', - '--rotate', '90', 'x', - '--rotate', '90', 'y', - '--rotate', '90', 'z', - '--add', '0,0,0', '6', - '--swap', '0', '1', - '--repeat', '2', 'x', - '--repeat', '2', 'y', - '--repeat', '2', 'z', - '--tile', '2', 'x', - '--tile', '2', 'y', - '--tile', '2', 'z', - '--untile', '2', 'z', - '--untile', '2', 'y', - '--untile', '2', 'x', + opts = [ + "--origin", + "--center-of", + "mass", + "--center-of", + "xyz", + "--center-of", + "position", + "--center-of", + "cell", + "--unit-cell", + "translate", + "--unit-cell", + "mod", + "--rotate", + "90", + "x", + "--rotate", + "90", + "y", + "--rotate", + "90", + "z", + "--add", + "0,0,0", + "6", + "--swap", + "0", + "1", + "--repeat", + "2", + "x", + "--repeat", + "2", + "y", + "--repeat", + "2", + "z", + "--tile", + "2", + "x", + "--tile", + "2", + "y", + "--tile", + "2", + "z", + "--untile", + "2", + "z", + "--untile", + "2", + "y", + "--untile", + "2", + "x", ] - if kwargs.get('limit_arguments', True): - opts.extend(['--rotate', '-90', 'x', - '--rotate', '-90', 'y', - '--rotate', '-90', 'z']) + if kwargs.get("limit_arguments", True): + opts.extend( + ["--rotate", "-90", "x", "--rotate", "-90", "y", "--rotate", "-90", "z"] + ) else: - opts.extend(['--rotate-x', ' -90', - '--rotate-y', ' -90', - '--rotate-z', ' -90', - '--repeat-x', '2', - '--repeat-y', '2', - '--repeat-z', '2']) + opts.extend( + [ + "--rotate-x", + " -90", + "--rotate-y", + " -90", + "--rotate-z", + " -90", + "--repeat-x", + "2", + "--repeat-y", + "2", + "--repeat-z", + "2", + ] + ) args = p.parse_args(opts, namespace=ns) @@ -1321,7 +1371,7 @@ def test_set_sc(self, setup): g1.set_sc(s1) assert g1.lattice == s1 assert len(deps) == 1 - + def test_set_supercell(self, setup): # check for deprecation s1 = Lattice([2, 2, 2]) @@ -1333,33 +1383,34 @@ def test_set_supercell(self, setup): def test_attach1(self, setup): g = setup.g.attach(0, setup.mol, 0, dist=1.42, axis=2) - g = setup.g.attach(0, setup.mol, 0, dist='calc', axis=2) + g = setup.g.attach(0, setup.mol, 0, dist="calc", axis=2) g = setup.g.attach(0, setup.mol, 0, dist=[0, 0, 1.42]) def test_mirror_function(self, setup): g = setup.g - for plane in ['xy', 'xz', 'yz', 'ab', 'bc', 'ac']: + for plane in ["xy", "xz", "yz", "ab", "bc", "ac"]: g.mirror(plane) - assert g.mirror('xy') == g.mirror('z') - assert g.mirror('xy') == g.mirror([0, 0, 1]) + assert g.mirror("xy") == g.mirror("z") + assert g.mirror("xy") == g.mirror([0, 0, 1]) - assert g.mirror('xy', [0]) == g.mirror([0, 0, 1], [0]) + assert g.mirror("xy", [0]) == g.mirror([0, 0, 1], [0]) def test_mirror_point(self): g = Geometry([[0, 0, 0], [0, 0, 1]]) - out = g.mirror('z') + out = g.mirror("z") assert np.allclose(out.xyz[:, 2], [0, -1]) assert np.allclose(out.xyz[:, :2], 0) - out = g.mirror('z', point=(0, 0, 0.5)) + out = g.mirror("z", point=(0, 0, 0.5)) assert np.allclose(out.xyz[:, 2], [1, 0]) assert np.allclose(out.xyz[:, :2], 0) - out = g.mirror('z', point=(0, 0, 1)) + out = g.mirror("z", point=(0, 0, 1)) assert np.allclose(out.xyz[:, 2], [2, 1]) assert np.allclose(out.xyz[:, :2], 0) def test_pickle(self, setup): import pickle as p + s = p.dumps(setup.g) n = p.loads(s) assert n == setup.g @@ -1368,32 +1419,32 @@ def test_geometry_names(self): g = sisl_geom.graphene() assert len(g.names) == 0 - g['A'] = 1 + g["A"] = 1 assert len(g.names) == 1 - g[[1, 2]] = 'B' + g[[1, 2]] = "B" assert len(g.names) == 2 - g.names.delete_name('B') + g.names.delete_name("B") assert len(g.names) == 1 # Add new group - g['B'] = [0, 2] + g["B"] = [0, 2] for name in g.names: - assert name in ['A', 'B'] + assert name in ["A", "B"] str(g) - assert np.allclose(g['B'], g[[0, 2], :]) - assert np.allclose(g.axyz('B'), g[[0, 2], :]) + assert np.allclose(g["B"], g[[0, 2], :]) + assert np.allclose(g.axyz("B"), g[[0, 2], :]) - del g.names['B'] + del g.names["B"] assert len(g.names) == 1 def test_geometry_groups_raise(self): g = sisl_geom.graphene() - g['A'] = 1 + g["A"] = 1 with pytest.raises(SislError): - g['A'] = [1, 2] + g["A"] = [1, 2] def test_geometry_as_primary_raise_nondivisable(self): g = sisl_geom.graphene() @@ -1408,20 +1459,25 @@ def test_geometry_untile_raise_nondivisable(self): def test_geometry_iR_negative_R(self): g = sisl_geom.graphene() with pytest.raises(ValueError): - g.iR(R=-1.) - - @pytest.mark.parametrize("geometry", [sisl_geom.graphene(), - sisl_geom.diamond(), - sisl_geom.sc(1.4, Atom[1]), - sisl_geom.fcc(1.4, Atom[1]), - sisl_geom.bcc(1.4, Atom[1]), - sisl_geom.hcp(1.4, Atom[1])]) + g.iR(R=-1.0) + + @pytest.mark.parametrize( + "geometry", + [ + sisl_geom.graphene(), + sisl_geom.diamond(), + sisl_geom.sc(1.4, Atom[1]), + sisl_geom.fcc(1.4, Atom[1]), + sisl_geom.bcc(1.4, Atom[1]), + sisl_geom.hcp(1.4, Atom[1]), + ], + ) def test_geometry_as_primary(self, geometry): prod = itertools.product x_reps = [1, 4, 3] y_reps = [1, 4, 5] z_reps = [1, 4, 6] - tile_rep = ['r', 't'] + tile_rep = ["r", "t"] na_primary = len(geometry) for x, y, z in prod(x_reps, y_reps, z_reps): @@ -1453,7 +1509,9 @@ def test_geometry_ase_new_to(self): from_ase = gr.new(to_ase) assert gr.equal(from_ase, R=False) - @pytest.mark.xfail(reason="pymatgen backconversion sets nsc=[3, 3, 3], we need to figure this out") + @pytest.mark.xfail( + reason="pymatgen backconversion sets nsc=[3, 3, 3], we need to figure this out" + ) def test_geometry_pymatgen_to(self): pytest.importorskip("pymatgen", reason="pymatgen not available") gr = sisl_geom.graphene() @@ -1501,7 +1559,9 @@ def test_geometry_sort_simple(): for ix in idx: assert np.all(np.diff(bi.fxyz[ix, 1]) >= -atol) - s, idx = bi.sort(axis=0, ascending=False, lattice=1, vector=[0, 0, 1], ret_atoms=True) + s, idx = bi.sort( + axis=0, ascending=False, lattice=1, vector=[0, 0, 1], ret_atoms=True + ) assert np.all(np.diff(s.xyz[:, 0]) >= -atol) for ix in idx: # idx is according to bi @@ -1526,7 +1586,9 @@ def test_geometry_sort_int(): for ix in idx: assert np.all(np.diff(bi.fxyz[ix, 1]) >= -atol) - s, idx = bi.sort(ascending1=True, axis15=0, ascending0=False, lattice235=1, ret_atoms=True) + s, idx = bi.sort( + ascending1=True, axis15=0, ascending0=False, lattice235=1, ret_atoms=True + ) assert np.all(np.diff(s.xyz[:, 0]) >= -atol) for ix in idx: # idx is according to bi @@ -1551,6 +1613,7 @@ def test_geometry_sort_func(): def reverse(geometry, atoms, **kwargs): return atoms[::-1] + atoms = [[2, 0], [3, 1]] out = bi.sort(func=reverse, atoms=atoms) @@ -1578,31 +1641,38 @@ def test_geometry_sort_func_sort(): # Sort according to another cell fractional coordinates fcc = sisl_geom.fcc(2.4, Atom(6)) + def fcc_fracs(axis): def _(geometry): return np.dot(geometry.xyz, fcc.icell.T)[:, axis] + return _ + out = bi.sort(func_sort=(fcc_fracs(0), fcc_fracs(2))) def test_geometry_sort_group(): - bi = sisl_geom.bilayer(bottom_atoms=Atom[6], top_atoms=(Atom[5], Atom[7])).tile(2, 0).repeat(2, 1) + bi = ( + sisl_geom.bilayer(bottom_atoms=Atom[6], top_atoms=(Atom[5], Atom[7])) + .tile(2, 0) + .repeat(2, 1) + ) - out = bi.sort(group='Z') + out = bi.sort(group="Z") assert np.allclose(out.atoms.Z[:4], 5) assert np.allclose(out.atoms.Z[4:12], 6) assert np.allclose(out.atoms.Z[12:16], 7) - out = bi.sort(group=('symbol', 'C', None)) + out = bi.sort(group=("symbol", "C", None)) assert np.allclose(out.atoms.Z[:8], 6) - C = bi.sort(group=('symbol', 'C', None)) - BN = bi.sort(group=('symbol', None, 'C')) - BN2 = bi.sort(group=('symbol', ['B', 'N'], 'C')) + C = bi.sort(group=("symbol", "C", None)) + BN = bi.sort(group=("symbol", None, "C")) + BN2 = bi.sort(group=("symbol", ["B", "N"], "C")) # For these simple geometries symbol and tag are the same - BN3 = bi.sort(group=('tag', ['B', 'N'], 'C')) + BN3 = bi.sort(group=("tag", ["B", "N"], "C")) # none of these atoms should be the same assert not np.any(np.isclose(C.atoms.Z, BN.atoms.Z)) @@ -1610,8 +1680,8 @@ def test_geometry_sort_group(): assert np.allclose(BN.atoms.Z, BN2.atoms.Z) assert np.allclose(BN.atoms.Z, BN3.atoms.Z) - mass = bi.sort(group='mass') - Z = bi.sort(group='Z') + mass = bi.sort(group="mass") + Z = bi.sort(group="Z") assert np.allclose(mass.atoms.Z, Z.atoms.Z) @@ -1623,7 +1693,11 @@ def test_geometry_sort_fail_keyword(): @pytest.mark.category @pytest.mark.geom_category def test_geometry_sanitize_atom_category(): - bi = sisl_geom.bilayer(bottom_atoms=Atom[6], top_atoms=(Atom[5], Atom[7])).tile(2, 0).repeat(2, 1) + bi = ( + sisl_geom.bilayer(bottom_atoms=Atom[6], top_atoms=(Atom[5], Atom[7])) + .tile(2, 0) + .repeat(2, 1) + ) C_idx = (bi.atoms.Z == 6).nonzero()[0] check_C = bi.axyz(C_idx) only_C = bi.axyz(Atom[6]) @@ -1631,26 +1705,28 @@ def test_geometry_sanitize_atom_category(): only_C = bi.axyz(bi.atoms.Z == 6) assert np.allclose(only_C, check_C) # with dict redirect - only_C = bi.axyz({'Z': 6}) + only_C = bi.axyz({"Z": 6}) assert np.allclose(only_C, check_C) # Using a dict that has multiple keys. This basically checks # that it accepts generic categories such as the AndCategory bi2 = bi.copy() bi2.atoms["C"] = Atom("C", R=1.9) - only_C = bi2.axyz({'Z': 6, "neighbours": 3}) + only_C = bi2.axyz({"Z": 6, "neighbours": 3}) assert np.allclose(only_C, check_C) tup_01 = (0, 2) list_01 = [0, 2] ndarray_01 = np.array(list_01) - assert np.allclose(bi._sanitize_atoms(tup_01), - bi._sanitize_atoms(list_01)) - assert np.allclose(bi._sanitize_atoms(ndarray_01), - bi._sanitize_atoms(list_01)) + assert np.allclose(bi._sanitize_atoms(tup_01), bi._sanitize_atoms(list_01)) + assert np.allclose(bi._sanitize_atoms(ndarray_01), bi._sanitize_atoms(list_01)) def test_geometry_sanitize_atom_shape(): - bi = sisl_geom.bilayer(bottom_atoms=Atom[6], top_atoms=(Atom[5], Atom[7])).tile(2, 0).repeat(2, 1) + bi = ( + sisl_geom.bilayer(bottom_atoms=Atom[6], top_atoms=(Atom[5], Atom[7])) + .tile(2, 0) + .repeat(2, 1) + ) cube = Cube(10) assert len(bi.axyz(cube)) != 0 @@ -1660,10 +1736,7 @@ def test_geometry_sanitize_atom_0_length(): assert len(gr.axyz([])) == 0 -@pytest.mark.parametrize("atoms", [[True, False], - (True, False), - [0], (0,) -]) +@pytest.mark.parametrize("atoms", [[True, False], (True, False), [0], (0,)]) def test_geometry_sanitize_atom_other_bool(atoms): gr = sisl_geom.graphene() assert len(gr.axyz(atoms)) == 1 @@ -1684,7 +1757,9 @@ def test_geometry_sanitize_orbs(): C_idx = (bi.atoms.Z == 6).nonzero()[0] assert np.allclose(bi._sanitize_orbs({bot: [0]}), bi.firsto[C_idx]) assert np.allclose(bi._sanitize_orbs({bot: 1}), bi.firsto[C_idx] + 1) - assert np.allclose(bi._sanitize_orbs({bot: [1, 2]}), np.add.outer(bi.firsto[C_idx], [1, 2]).ravel()) + assert np.allclose( + bi._sanitize_orbs({bot: [1, 2]}), np.add.outer(bi.firsto[C_idx], [1, 2]).ravel() + ) def test_geometry_sub_orbitals(): @@ -1703,7 +1778,7 @@ def test_geometry_sub_orbitals(): def test_geometry_new_xyz(sisl_tmp): # test that Geometry.new works - out = sisl_tmp('out.xyz', _dir) + out = sisl_tmp("out.xyz", _dir) C = Atom[6] gr = sisl_geom.graphene(atoms=C) # writing doesn't save orbital information, so we force @@ -1741,14 +1816,14 @@ def test_translate2uc_axes(): def test_as_supercell_graphene(): gr = sisl_geom.graphene() grsc = gr.as_supercell() - assert np.allclose(grsc.xyz[:len(gr)], gr.xyz) + assert np.allclose(grsc.xyz[: len(gr)], gr.xyz) assert np.allclose(grsc.axyz(np.arange(gr.na_s)), gr.axyz(np.arange(gr.na_s))) def test_as_supercell_fcc(): - g = sisl_geom.fcc(2 ** 0.5, Atom(1, R=1.0001)) + g = sisl_geom.fcc(2**0.5, Atom(1, R=1.0001)) gsc = g.as_supercell() - assert np.allclose(gsc.xyz[:len(g)], g.xyz) + assert np.allclose(gsc.xyz[: len(g)], g.xyz) assert np.allclose(gsc.axyz(np.arange(g.na_s)), g.axyz(np.arange(g.na_s))) diff --git a/src/sisl/tests/test_geometry_return.py b/src/sisl/tests/test_geometry_return.py index 1f4d10ad8b..315d0e8063 100644 --- a/src/sisl/tests/test_geometry_return.py +++ b/src/sisl/tests/test_geometry_return.py @@ -12,28 +12,36 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) C = Atom(Z=6, R=bond * 1.01) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) C = Atom(Z=6, R=[bond * 1.01] * 2) - self.g2 = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g2 = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) + return t() @pytest.mark.geometry class TestGeometryReturn: - def test_fl_o(self, setup): # first o is always one element longer than the number of atoms assert np.all(setup.g.firsto == [0, 1, 2]) @@ -61,7 +69,7 @@ def test_rij1(self, setup): def test_rij2(self, setup): d = setup.g.rij([0, 1], [0, 1]) - assert np.allclose(d, [0., 0.]) + assert np.allclose(d, [0.0, 0.0]) def test_osc2uc(self, setup): # single value @@ -83,7 +91,7 @@ def test_slice1(self, setup): assert d.shape == (2, 3) d = setup.g[1, 2] - assert d == 0. + assert d == 0.0 d = setup.g[1, 1:3] assert d.shape == (2,) d = setup.g[1, :] diff --git a/src/sisl/tests/test_grid.py b/src/sisl/tests/test_grid.py index 1d569ee54b..598b64a6a8 100644 --- a/src/sisl/tests/test_grid.py +++ b/src/sisl/tests/test_grid.py @@ -24,21 +24,25 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): alat = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * alat, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * alat, + nsc=[3, 3, 1], + ) self.g = Grid([10, 10, 100], lattice=self.lattice) - self.g.fill(2.) + self.g.fill(2.0) + return t() @pytest.mark.grid class TestGrid: - def test_print(self, setup): str(setup.g) @@ -68,7 +72,7 @@ def test_item(self, setup): assert np.allclose(setup.g[1:2, 1:2, 2:3], setup.g.grid[1:2, 1:2, 2:3]) def test_dcell(self, setup): - assert np.all(setup.g.dcell*setup.g.cell >= 0) + assert np.all(setup.g.dcell * setup.g.cell >= 0) def test_dvolume(self, setup): assert setup.g.dvolume > 0 @@ -96,11 +100,11 @@ def test_add1(self, setup): assert np.allclose(g.grid, (setup.g / 2).grid) def test_add2(self, setup): - g = setup.g + 2. + g = setup.g + 2.0 assert np.allclose(g.grid, setup.g.grid + 2) g = setup.g.copy() - g += 2. - g -= 2. + g += 2.0 + g -= 2.0 assert np.allclose(g.grid, setup.g.grid) g = setup.g + setup.g assert np.allclose(g.grid, setup.g.grid * 2) @@ -152,7 +156,10 @@ def test_average_weight(self, setup): shape = g.shape for i in range(3): w = np.zeros(shape[i]) + 0.5 - assert g.average(i, weights=w).grid.sum() == shape[0] * shape[1] * shape[2] / shape[i] + assert ( + g.average(i, weights=w).grid.sum() + == shape[0] * shape[1] * shape[2] / shape[i] + ) def test_interp(self, setup): shape = np.array(setup.g.shape, np.int32) @@ -198,7 +205,6 @@ def test_interp_extrap(self, setup): assert np.allclose(setup.g.sum(2).grid, g1.grid) def test_isosurface_orthogonal(self, setup): - pytest.importorskip("skimage", reason="scikit-image not available") # Build an empty grid @@ -216,7 +222,6 @@ def test_isosurface_orthogonal(self, setup): assert np.unique(verts[:, 2]).shape == (2,) def test_isosurface_non_orthogonal(self, setup): - pytest.importorskip("skimage", reason="scikit-image not available") # If the grid is non-orthogonal, there should be 20 unique values @@ -250,16 +255,16 @@ def test_smooth_uniform(self, setup): # With a radius of 0.7 Ang, that single value should be propagated # to the whole grid - smoothed = g.smooth(r=1., method='uniform') + smoothed = g.smooth(r=1.0, method="uniform") assert not np.any(smoothed.grid == 0) # With a sigma of 0.1 Ang, borders should still be 0 - smoothed = g.smooth(r=0.9, method='uniform') + smoothed = g.smooth(r=0.9, method="uniform") assert np.any(smoothed.grid == 0) def test_index_ndim1(self, setup): mid = np.array(setup.g.shape, np.int32) // 2 - 1 - v = [0.001, 0., 0.001] + v = [0.001, 0.0, 0.001] idx = setup.g.index(setup.lattice.center() - v) assert np.all(mid == idx) for i in range(3): @@ -272,18 +277,18 @@ def test_index_fail(self, setup): def test_index_ndim2(self, setup): mid = np.array(setup.g.shape, np.int32) // 2 - 1 - v = [0.001, 0., 0.001] - idx = setup.g.index([[0]*3, setup.lattice.center() - v]) + v = [0.001, 0.0, 0.001] + idx = setup.g.index([[0] * 3, setup.lattice.center() - v]) assert np.allclose([[0] * 3, mid], idx) for i in range(3): - idx = setup.g.index([[0]*3, setup.lattice.center() - v], axis=i) + idx = setup.g.index([[0] * 3, setup.lattice.center() - v], axis=i) assert np.allclose([[0, 0, 0][i], mid[i]], idx) def test_index_shape1(self, setup): g = setup.g.copy() n = 0 - for r in [0.5, 1., 1.5]: + for r in [0.5, 1.0, 1.5]: s = Ellipsoid(r) idx = g.index(s) assert len(idx) > n @@ -295,12 +300,12 @@ def test_index_shape1(self, setup): # offset v = g.dcell.sum(0) vd = v * 0.001 - s = Ellipsoid(1.) + s = Ellipsoid(1.0) idx0 = g.index(s) idx0.sort(0) for d in [10, 15, 20, 60, 100, 340]: idx = g.index(v * d + vd) - s = Ellipsoid(1., center=v * d + vd) + s = Ellipsoid(1.0, center=v * d + vd) idx1 = g.index(s) idx1.sort(0) assert len(idx1) == len(idx0) @@ -309,7 +314,7 @@ def test_index_shape1(self, setup): def test_index_shape2(self, setup): g = setup.g.copy() n = 0 - for r in [0.5, 1., 1.5]: + for r in [0.5, 1.0, 1.5]: s = Cuboid(r) idx = g.index(s) assert len(idx) > n @@ -321,12 +326,12 @@ def test_index_shape2(self, setup): # offset v = g.dcell.sum(0) vd = v * 0.001 - s = Cuboid(1.) + s = Cuboid(1.0) idx0 = g.index(s) idx0.sort(0) for d in [10, 15, 20, 60, 100, 340]: idx = g.index(v * d + vd) - s = Cuboid(1., center=v * d + vd) + s = Cuboid(1.0, center=v * d + vd) idx1 = g.index(s) idx1.sort(0) assert len(idx1) == len(idx0) @@ -370,9 +375,9 @@ def test_sub(self, setup): def test_remove(self, setup): for i in range(3): - assert setup.g.remove(1, i).shape[i] == setup.g.shape[i]-1 + assert setup.g.remove(1, i).shape[i] == setup.g.shape[i] - 1 for i in range(3): - assert setup.g.remove([1, 2], i).shape[i] == setup.g.shape[i]-2 + assert setup.g.remove([1, 2], i).shape[i] == setup.g.shape[i] - 2 def test_set_grid1(self, setup): g = setup.g.copy() @@ -391,7 +396,7 @@ def test_argumentparser(self, setup): def test_pyamg1(self, setup): g = setup.g.copy() - g.lattice.set_boundary_condition(g.PERIODIC) # periodic boundary conditions + g.lattice.set_boundary_condition(g.PERIODIC) # periodic boundary conditions n = np.prod(g.shape) A = csr_matrix((n, n)) b = np.zeros(A.shape[0]) @@ -416,9 +421,7 @@ def test_pyamg2(self, setup): # Nothing is actually tested other than succesfull run, # the correctness of the values are not. g = setup.g.copy() - bc = [[g.PERIODIC] * 2, - [g.NEUMANN, g.DIRICHLET], - [g.DIRICHLET, g.NEUMANN]] + bc = [[g.PERIODIC] * 2, [g.NEUMANN, g.DIRICHLET], [g.DIRICHLET, g.NEUMANN]] g.lattice.set_boundary_condition(bc) n = np.prod(g.shape) A = csr_matrix((n, n)) @@ -437,15 +440,13 @@ def test_grid_fold(): assert np.all(grid.index_fold([[-1, -1, -1]] * 2) == [3, 4, 5]) assert np.all(grid.index_fold([[-1, -1, -1]] * 2, False) == [[3, 4, 5]] * 2) - idx = [[-1, 0, 0], - [3, 0, 0]] + idx = [[-1, 0, 0], [3, 0, 0]] assert np.all(grid.index_fold(idx) == [3, 0, 0]) assert np.all(grid.index_fold(idx, False) == [[3, 0, 0]] * 2) - idx = [[3, 0, 0], - [2, 0, 0]] + idx = [[3, 0, 0], [2, 0, 0]] assert np.all(grid.index_fold(idx, False) == idx) - assert not np.all(grid.index_fold(idx) == idx) # sorted from unique + assert not np.all(grid.index_fold(idx) == idx) # sorted from unique assert np.all(grid.index_fold(idx) == np.sort(idx, axis=0)) @@ -464,7 +465,7 @@ def test_grid_tile_sc(): def test_grid_tile_geom(): - grid = Grid([4, 5, 6], geometry=Geometry([0] * 3, Atom[4], lattice=4.)) + grid = Grid([4, 5, 6], geometry=Geometry([0] * 3, Atom[4], lattice=4.0)) grid2 = grid.tile(2, 2) assert grid.shape[:2] == grid2.shape[:2] assert grid.shape[2] == grid2.shape[2] // 2 @@ -493,7 +494,7 @@ def test_grid_tile_commensurate(): def test_grid_tile_in_commensurate(): gr = geom.graphene() - lat = Lattice(4.) + lat = Lattice(4.0) grid = Grid([4, 5, 6], geometry=gr, lattice=lat) str(grid) with pytest.raises(SislError): diff --git a/src/sisl/tests/test_help.py b/src/sisl/tests/test_help.py index bcf9d61076..e5e0f6ce7f 100644 --- a/src/sisl/tests/test_help.py +++ b/src/sisl/tests/test_help.py @@ -7,7 +7,12 @@ import numpy as np import pytest -from sisl._help import array_fill_repeat, array_replace, dtype_complex_to_real, get_dtype +from sisl._help import ( + array_fill_repeat, + array_replace, + dtype_complex_to_real, + get_dtype, +) pytestmark = pytest.mark.help @@ -23,13 +28,17 @@ def test_array_fill_repeat2(): array_fill_repeat([1, 2, 3], 20) -@pytest.mark.xfail(sys.platform.startswith("win"), reason="Datatype cannot be int64 on windows") +@pytest.mark.xfail( + sys.platform.startswith("win"), reason="Datatype cannot be int64 on windows" +) def test_get_dtype1(): assert np.int32 == get_dtype(1) assert np.int64 == get_dtype(1, int=np.int64) -@pytest.mark.xfail(sys.platform.startswith("win"), reason="Datatype cannot be int64 on windows") +@pytest.mark.xfail( + sys.platform.startswith("win"), reason="Datatype cannot be int64 on windows" +) def test_dtype_complex_to_real(): for d in (np.int32, np.int64, np.float32, np.float64): assert dtype_complex_to_real(d) == d diff --git a/src/sisl/tests/test_lattice.py b/src/sisl/tests/test_lattice.py index e1f55c83b2..54e3b1252f 100644 --- a/src/sisl/tests/test_lattice.py +++ b/src/sisl/tests/test_lattice.py @@ -17,22 +17,26 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): alat = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * alat, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * alat, + nsc=[3, 3, 1], + ) + return t() class TestLattice: - def test_str(self, setup): str(setup.lattice) str(setup.lattice) - assert setup.lattice != 'Not a Lattice' + assert setup.lattice != "Not a Lattice" def test_nsc1(self, setup): lattice = setup.lattice.copy() @@ -80,10 +84,10 @@ def test_fill(self, setup): sc = setup.lattice.swapaxes(1, 2) i = sc._fill([1, 1]) assert i.dtype == np.int32 - i = sc._fill([1., 1.]) + i = sc._fill([1.0, 1.0]) assert i.dtype == np.float64 for dt in [np.int32, np.float32, np.float64, np.complex64]: - i = sc._fill([1., 1.], dt) + i = sc._fill([1.0, 1.0], dt) assert i.dtype == dt i = sc._fill(np.ones([2], dt)) assert i.dtype == dt @@ -93,7 +97,7 @@ def test_add_vacuum_direct(self, setup): for i in range(3): s = sc.add_vacuum(10, i) ax = setup.lattice.cell[i, :] - ax += ax / np.sum(ax ** 2) ** .5 * 10 + ax += ax / np.sum(ax**2) ** 0.5 * 10 assert np.allclose(ax, s.cell[i, :]) def test_add_vacuum_orthogonal(self, setup): @@ -103,11 +107,15 @@ def test_add_vacuum_orthogonal(self, setup): # now check for the skewed ones s = sc.add_vacuum(sc.length[0], 0, orthogonal_to_plane=True) - assert (s.length[0] / sc.length[0] - 1) * m.cos(m.radians(30)) == pytest.approx(1.) + assert (s.length[0] / sc.length[0] - 1) * m.cos(m.radians(30)) == pytest.approx( + 1.0 + ) # now check for the skewed ones s = sc.add_vacuum(sc.length[1], 1, orthogonal_to_plane=True) - assert (s.length[1] / sc.length[1] - 1) * m.cos(m.radians(30)) == pytest.approx(1.) + assert (s.length[1] / sc.length[1] - 1) * m.cos(m.radians(30)) == pytest.approx( + 1.0 + ) def test_add1(self, setup): sc = setup.lattice.copy() @@ -170,7 +178,6 @@ def test_swapaxes_xyz(self, setup, a, b): assert np.allclose(lattice.origin[[b, a]], sab.origin[[a, b]]) def test_swapaxes_complicated(self, setup): - # swap a couple of lattice vectors and cartesian coordinates a = "azby" b = "bxcz" @@ -201,8 +208,7 @@ def test_sc_index1(self, setup): assert len(sc_index) == setup.lattice.nsc[2] def test_sc_index2(self, setup): - sc_index = setup.lattice.sc_index([[0, 0, 0], - [1, 1, 0]]) + sc_index = setup.lattice.sc_index([[0, 0, 0], [1, 1, 0]]) s = str(sc_index) assert len(sc_index) == 2 @@ -238,11 +244,11 @@ def test_creation1(self, setup): def test_creation2(self, setup): # full cell class P(LatticeChild): - def copy(self): a = P() a.set_lattice(setup.lattice) return a + tmp1 = P() tmp1.set_lattice([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) # diagonal cell @@ -284,14 +290,18 @@ def test_creation_rotate(self, setup): assert np.allclose(param, lattice.parameters()) assert np.allclose(parama, lattice.parameters(True)) for ang in range(0, 91, 5): - s = lattice.rotate(ang, lattice.cell[0, :]).rotate(ang, lattice.cell[1, :]).rotate(ang, lattice.cell[2, :]) + s = ( + lattice.rotate(ang, lattice.cell[0, :]) + .rotate(ang, lattice.cell[1, :]) + .rotate(ang, lattice.cell[2, :]) + ) assert np.allclose(param, s.parameters()) assert np.allclose(parama, s.parameters(True)) def test_rcell(self, setup): # LAPACK inverse algorithm implicitly does # a transpose. - rcell = lin.inv(setup.lattice.cell) * 2. * np.pi + rcell = lin.inv(setup.lattice.cell) * 2.0 * np.pi assert np.allclose(rcell.T, setup.lattice.rcell) assert np.allclose(rcell.T / (2 * np.pi), setup.lattice.icell) @@ -301,15 +311,18 @@ def test_icell(self, setup): def test_translate1(self, setup): lattice = setup.lattice.translate([0, 0, 10]) assert np.allclose(lattice.cell[2, :2], setup.lattice.cell[2, :2]) - assert np.allclose(lattice.cell[2, 2], setup.lattice.cell[2, 2]+10) + assert np.allclose(lattice.cell[2, 2], setup.lattice.cell[2, 2] + 10) def test_center1(self, setup): - assert np.allclose(setup.lattice.center(), np.sum(setup.lattice.cell, axis=0) / 2) + assert np.allclose( + setup.lattice.center(), np.sum(setup.lattice.cell, axis=0) / 2 + ) for i in [0, 1, 2]: assert np.allclose(setup.lattice.center(i), setup.lattice.cell[i, :] / 2) def test_pickle(self, setup): import pickle as p + s = p.dumps(setup.lattice) n = p.loads(s) assert setup.lattice == n @@ -359,15 +372,23 @@ def test_parallel1(self, setup): def test_tile_multiply_orthogonal(self): lattice = graphene(orthogonal=True).lattice - assert np.allclose(lattice.tile(3, 0).tile(2, 1).tile(4, 2).cell, (lattice * (3, 2, 4)).cell) + assert np.allclose( + lattice.tile(3, 0).tile(2, 1).tile(4, 2).cell, (lattice * (3, 2, 4)).cell + ) assert np.allclose(lattice.tile(3, 0).tile(2, 1).cell, (lattice * [3, 2]).cell) - assert np.allclose(lattice.tile(3, 0).tile(3, 1).tile(3, 2).cell, (lattice * 3).cell) + assert np.allclose( + lattice.tile(3, 0).tile(3, 1).tile(3, 2).cell, (lattice * 3).cell + ) def test_tile_multiply_non_orthogonal(self): lattice = graphene(orthogonal=False).lattice - assert np.allclose(lattice.tile(3, 0).tile(2, 1).tile(4, 2).cell, (lattice * (3, 2, 4)).cell) + assert np.allclose( + lattice.tile(3, 0).tile(2, 1).tile(4, 2).cell, (lattice * (3, 2, 4)).cell + ) assert np.allclose(lattice.tile(3, 0).tile(2, 1).cell, (lattice * [3, 2]).cell) - assert np.allclose(lattice.tile(3, 0).tile(3, 1).tile(3, 2).cell, (lattice * 3).cell) + assert np.allclose( + lattice.tile(3, 0).tile(3, 1).tile(3, 2).cell, (lattice * 3).cell + ) def test_angle1(self, setup): g = graphene(orthogonal=True) @@ -379,15 +400,11 @@ def test_angle2(self, setup): assert lattice.angle(0, 1) == 90 assert lattice.angle(0, 2) == 90 assert lattice.angle(1, 2) == 90 - lattice = Lattice([[1, 1, 0], - [1, -1, 0], - [0, 0, 2]]) + lattice = Lattice([[1, 1, 0], [1, -1, 0], [0, 0, 2]]) assert lattice.angle(0, 1) == 90 assert lattice.angle(0, 2) == 90 assert lattice.angle(1, 2) == 90 - lattice = Lattice([[3, 4, 0], - [4, 3, 0], - [0, 0, 2]]) + lattice = Lattice([[3, 4, 0], [4, 3, 0], [0, 0, 2]]) assert lattice.angle(0, 1, rad=True) == approx(0.28379, abs=1e-4) assert lattice.angle(0, 2) == 90 assert lattice.angle(1, 2) == 90 @@ -395,7 +412,9 @@ def test_angle2(self, setup): def test_cell2length(self): gr = graphene(orthogonal=True) lattice = (gr * (40, 40, 1)).rotate(24, gr.cell[2, :]).lattice - assert np.allclose(lattice.length, (lattice.cell2length(lattice.length) ** 2).sum(1) ** 0.5) + assert np.allclose( + lattice.length, (lattice.cell2length(lattice.length) ** 2).sum(1) ** 0.5 + ) assert np.allclose(1, (lattice.cell2length(1) ** 2).sum(0)) def test_set_lattice_off_wrong_size(self, setup): @@ -405,7 +424,7 @@ def test_set_lattice_off_wrong_size(self, setup): def _dot(u, v): - """ Dot product u . v """ + """Dot product u . v""" return u[0] * v[0] + u[1] * v[1] + u[2] * v[2] @@ -493,13 +512,13 @@ def test_supercell_warn(): # boundary-condition tests ##### + @pytest.mark.parametrize("bc", list(BoundaryCondition)) def test_lattice_bc_init(bc): Lattice(1, boundary_condition=bc) Lattice(1, boundary_condition=[bc, bc, bc]) - Lattice(1, boundary_condition=[[bc, bc], - [bc, bc], - [bc, bc]]) + Lattice(1, boundary_condition=[[bc, bc], [bc, bc], [bc, bc]]) + def test_lattice_bc_set(): lat = Lattice(1, boundary_condition=Lattice.BC.PERIODIC) @@ -508,24 +527,26 @@ def test_lattice_bc_set(): assert not lat.pbc.any() assert (lat.boundary_condition == Lattice.BC.UNKNOWN).all() - lat.boundary_condition = ['per', "unkn", [3, 4]] + lat.boundary_condition = ["per", "unkn", [3, 4]] for n in "abc": lat.set_boundary_condition(**{n: "per"}) lat.set_boundary_condition(**{n: [3, Lattice.BC.UNKNOWN]}) lat.set_boundary_condition(**{n: [True, Lattice.BC.PERIODIC]}) - bc = [ - "per", - ["Dirichlet", Lattice.BC.NEUMANN], - ["un", "neu"] - ] + bc = ["per", ["Dirichlet", Lattice.BC.NEUMANN], ["un", "neu"]] lat.set_boundary_condition(bc) - assert np.all(lat.boundary_condition[1] == [Lattice.BC.DIRICHLET, Lattice.BC.NEUMANN]) + assert np.all( + lat.boundary_condition[1] == [Lattice.BC.DIRICHLET, Lattice.BC.NEUMANN] + ) assert np.all(lat.boundary_condition[2] == [Lattice.BC.UNKNOWN, Lattice.BC.NEUMANN]) lat.set_boundary_condition(["per", None, ["dirichlet", "unkno"]]) - assert np.all(lat.boundary_condition[1] == [Lattice.BC.DIRICHLET, Lattice.BC.NEUMANN]) - assert np.all(lat.boundary_condition[2] == [Lattice.BC.DIRICHLET, Lattice.BC.UNKNOWN]) + assert np.all( + lat.boundary_condition[1] == [Lattice.BC.DIRICHLET, Lattice.BC.NEUMANN] + ) + assert np.all( + lat.boundary_condition[2] == [Lattice.BC.DIRICHLET, Lattice.BC.UNKNOWN] + ) lat.set_boundary_condition("ppu") assert np.all(lat.pbc == [True, True, False]) @@ -546,4 +567,3 @@ def test_lattice_info(): lat.set_boundary_condition(b=Lattice.BC.DIRICHLET) lat.set_boundary_condition(c=Lattice.BC.PERIODIC) assert len(record) == 1 - diff --git a/src/sisl/tests/test_messages.py b/src/sisl/tests/test_messages.py index 8683b1f10d..8c1ae0c9a3 100644 --- a/src/sisl/tests/test_messages.py +++ b/src/sisl/tests/test_messages.py @@ -12,63 +12,63 @@ def test_deprecate(): with pytest.warns(sm.SislDeprecation): - sm.deprecate('Deprecation warning') + sm.deprecate("Deprecation warning") def test_deprecation(): with pytest.warns(sm.SislDeprecation): - w.warn(sm.SislDeprecation('Deprecation warning')) + w.warn(sm.SislDeprecation("Deprecation warning")) def test_warn_method(): with pytest.warns(sm.SislWarning): - sm.warn('Warning') + sm.warn("Warning") def test_warn_specific(): with pytest.warns(sm.SislWarning): - sm.warn(sm.SislWarning('Warning')) + sm.warn(sm.SislWarning("Warning")) def test_warn_category(): with pytest.warns(sm.SislWarning): - sm.warn('Warning', sm.SislWarning) + sm.warn("Warning", sm.SislWarning) def test_info_method(): with pytest.warns(sm.SislInfo): - sm.info('Information') + sm.info("Information") def test_info_specific(): with pytest.warns(sm.SislInfo): - sm.info(sm.SislInfo('Info')) + sm.info(sm.SislInfo("Info")) def test_info_category(): with pytest.warns(sm.SislInfo): - sm.info('Information', sm.SislInfo) + sm.info("Information", sm.SislInfo) def test_error(): with pytest.raises(sm.SislError): - raise sm.SislError('This is an error') + raise sm.SislError("This is an error") def test_exception(): with pytest.raises(sm.SislException): - raise sm.SislException('This is an error') + raise sm.SislException("This is an error") def test_progressbar_true(): - eta = sm.progressbar(2, 'Hello', 'unit', True) + eta = sm.progressbar(2, "Hello", "unit", True) eta.update() eta.update() eta.close() def test_progressbar_false(): - eta = sm.progressbar(2, 'Hello', 'unit', False) + eta = sm.progressbar(2, "Hello", "unit", False) eta.update() eta.update() eta.close() diff --git a/src/sisl/tests/test_namedindex.py b/src/sisl/tests/test_namedindex.py index 66a1942254..6d4f027784 100644 --- a/src/sisl/tests/test_namedindex.py +++ b/src/sisl/tests/test_namedindex.py @@ -10,31 +10,30 @@ def test_ni_init(): - ni = NamedIndex() str(ni) - ni = NamedIndex('name', [1]) + ni = NamedIndex("name", [1]) str(ni) - ni = NamedIndex(['name-1', 'name-2'], [[1], [0]]) + ni = NamedIndex(["name-1", "name-2"], [[1], [0]]) def test_ni_iter(): ni = NamedIndex() assert len(ni) == 0 - ni.add_name('name-1', [0]) + ni.add_name("name-1", [0]) assert len(ni) == 1 - ni.add_name('name-2', [1]) + ni.add_name("name-2", [1]) assert len(ni) == 2 for n in ni: - assert n in ['name-1', 'name-2'] - assert 'name-1' in ni - assert 'name-2' in ni + assert n in ["name-1", "name-2"] + assert "name-1" in ni + assert "name-2" in ni def test_ni_clear(): ni = NamedIndex() assert len(ni) == 0 - ni.add_name('name-1', [0]) + ni.add_name("name-1", [0]) assert len(ni) == 1 ni.clear() assert len(ni) == 0 @@ -42,28 +41,28 @@ def test_ni_clear(): def test_ni_copy(): ni = NamedIndex() - ni.add_name('name-1', [0]) - ni.add_name('name-2', [1]) + ni.add_name("name-1", [0]) + ni.add_name("name-2", [1]) n2 = ni.copy() assert ni._name == n2._name def test_ni_delete(): ni = NamedIndex() - ni.add_name('name-1', [0]) - ni.add_name('name-2', [1]) - ni.delete_name('name-1') + ni.add_name("name-1", [0]) + ni.add_name("name-2", [1]) + ni.delete_name("name-1") for n in ni: - assert n in ['name-2'] + assert n in ["name-2"] def test_ni_items(): ni = NamedIndex() - ni['Hello'] = [0] - ni[[1, 2]] = 'Hello-1' - assert np.all(ni['Hello'] == [0]) + ni["Hello"] = [0] + ni[[1, 2]] = "Hello-1" + assert np.all(ni["Hello"] == [0]) no = ni.remove_index(1) - assert np.all(no['Hello-1'] == [2]) + assert np.all(no["Hello-1"] == [2]) def test_ni_dict(): diff --git a/src/sisl/tests/test_oplist.py b/src/sisl/tests/test_oplist.py index f725b39be7..612324d873 100644 --- a/src/sisl/tests/test_oplist.py +++ b/src/sisl/tests/test_oplist.py @@ -23,7 +23,9 @@ def test_oplist_creation(): assert l[1] == 2 -@pytest.mark.parametrize("op", [ops.add, ops.floordiv, ops.sub, ops.mul, ops.truediv, ops.pow]) +@pytest.mark.parametrize( + "op", [ops.add, ops.floordiv, ops.sub, ops.mul, ops.truediv, ops.pow] +) @pytest.mark.parametrize("key1", [1, 2]) @pytest.mark.parametrize("key2", [1, 2]) def test_oplist_math(op, key1, key2): @@ -43,7 +45,9 @@ def test_oplist_single(op): op(d) -@pytest.mark.parametrize("op", [ops.iadd, ops.ifloordiv, ops.isub, ops.imul, ops.itruediv, ops.ipow]) +@pytest.mark.parametrize( + "op", [ops.iadd, ops.ifloordiv, ops.isub, ops.imul, ops.itruediv, ops.ipow] +) @pytest.mark.parametrize("key", [1, 2]) def test_oplist_imath(op, key): d = { @@ -87,6 +91,7 @@ def test_oplist_deco(): @oplist.decorate def my_func(): return 1 + a = my_func() assert isinstance(a, oplist) assert a[0] == 1 @@ -94,6 +99,7 @@ def my_func(): @oplist.decorate def my_func(): return [2, 3] + a = my_func() assert isinstance(a, oplist) assert len(a) == 2 @@ -102,6 +108,7 @@ def my_func(): @oplist.decorate def my_func(): return oplist([1, 2]) + a = my_func() assert isinstance(a, oplist) assert len(a) == 2 diff --git a/src/sisl/tests/test_orbital.py b/src/sisl/tests/test_orbital.py index 235dd822c7..628accb538 100644 --- a/src/sisl/tests/test_orbital.py +++ b/src/sisl/tests/test_orbital.py @@ -22,6 +22,7 @@ _max_l = len(_rspher_harm_fact) - 1 + def r_f(n): r = np.arange(n) return r, r @@ -29,7 +30,7 @@ def r_f(n): def test_spherical(): rad2 = np.pi / 45 - r, theta, phi = np.ogrid[0.1:10:0.2, -np.pi:np.pi:rad2, 0:np.pi:rad2] + r, theta, phi = np.ogrid[0.1:10:0.2, -np.pi : np.pi : rad2, 0 : np.pi : rad2] xyz = spher2cart(r, theta, phi) s = xyz.shape[:-1] r1, theta1, phi1 = cart2spher(xyz) @@ -42,44 +43,44 @@ def test_spherical(): class Test_orbital: - def test_init1(self): - assert Orbital(1.) == Orbital(1.) - assert Orbital(1., tag='none') != Orbital(1.) - assert Orbital(1., 1.) != Orbital(1.) - assert Orbital(1., 1.) != Orbital(1., 1., tag='none') + assert Orbital(1.0) == Orbital(1.0) + assert Orbital(1.0, tag="none") != Orbital(1.0) + assert Orbital(1.0, 1.0) != Orbital(1.0) + assert Orbital(1.0, 1.0) != Orbital(1.0, 1.0, tag="none") def test_basic1(self): - orb = Orbital(1.) + orb = Orbital(1.0) str(orb) - orb = Orbital(1., tag='none') + orb = Orbital(1.0, tag="none") str(orb) - orb = Orbital(1., 1., tag='none') + orb = Orbital(1.0, 1.0, tag="none") str(orb) assert orb == orb.copy() - assert orb != 1. + assert orb != 1.0 def test_copy(self): - orb = Orbital(1.) + orb = Orbital(1.0) assert orb.R == orb.copy().R - orb = Orbital(-1.) + orb = Orbital(-1.0) assert orb.R == orb.copy().R def test_psi1(self): # Orbital does not have radial part with pytest.raises(NotImplementedError): - Orbital(1.).psi(np.arange(10)) + Orbital(1.0).psi(np.arange(10)) def test_scale1(self): - o = Orbital(1.) - assert o.scale(2).R == 2. + o = Orbital(1.0) + assert o.scale(2).R == 2.0 o = Orbital(-1) - assert o.scale(2).R == -1. + assert o.scale(2).R == -1.0 def test_pickle1(self): import pickle as p - o0 = Orbital(1.) - o1 = Orbital(1., tag='none') + + o0 = Orbital(1.0) + o1 = Orbital(1.0, tag="none") p0 = p.dumps(o0) p1 = p.dumps(o1) l0 = p.loads(p0) @@ -91,16 +92,17 @@ def test_pickle1(self): class Test_sphericalorbital: - def test_init1(self): n = 6 rf = np.arange(n) rf = (rf, rf) assert SphericalOrbital(1, rf) == SphericalOrbital(1, rf) - f = interp.interp1d(rf[0], rf[1], fill_value=(0., 0.), bounds_error=False, kind='cubic') + f = interp.interp1d( + rf[0], rf[1], fill_value=(0.0, 0.0), bounds_error=False, kind="cubic" + ) rf = [rf[0], rf[0]] assert SphericalOrbital(1, rf) == SphericalOrbital(1, f) - assert SphericalOrbital(1, rf, tag='none') != SphericalOrbital(1, rf) + assert SphericalOrbital(1, rf, tag="none") != SphericalOrbital(1, rf) SphericalOrbital(5, rf) for l in range(10): o = SphericalOrbital(l, rf) @@ -111,14 +113,14 @@ def test_basic1(self): rf = r_f(6) orb = SphericalOrbital(1, rf) str(orb) - orb = SphericalOrbital(1, rf, tag='none') + orb = SphericalOrbital(1, rf, tag="none") str(orb) def test_copy(self): rf = r_f(6) - orb = SphericalOrbital(1, rf, R=2.) + orb = SphericalOrbital(1, rf, R=2.0) assert orb.R == orb.copy().R - assert orb.R == pytest.approx(2.) + assert orb.R == pytest.approx(2.0) orb = SphericalOrbital(1, rf) assert orb.R == orb.copy().R @@ -126,7 +128,7 @@ def test_set_radial1(self): rf = r_f(6) o = SphericalOrbital(1, rf) with pytest.raises(ValueError): - o.set_radial(1.) + o.set_radial(1.0) def test_set_radial_none(self): rf = r_f(6) @@ -145,10 +147,10 @@ def test_radial1(self): r0 = orb0.radial(r) r1 = orb1.radial(r) rr = np.stack((r, np.zeros(len(r)), np.zeros(len(r))), axis=1) - r2 = orb1.radial((rr ** 2).sum(-1) ** 0.5) + r2 = orb1.radial((rr**2).sum(-1) ** 0.5) assert np.allclose(r0, r1) assert np.allclose(r0, r2) - r[r >= rf[0].max()] = 0. + r[r >= rf[0].max()] = 0.0 assert np.allclose(r0, r) assert np.allclose(r1, r) @@ -174,14 +176,20 @@ def test_psi1(self): def test_radial_func1(self): r = np.linspace(0, 4, 300) f = np.exp(-r) - o = SphericalOrbital(1, (r, f), R=4.) + o = SphericalOrbital(1, (r, f), R=4.0) str(o) + def i_univariate(r, f): return interp.UnivariateSpline(r, f, k=5, s=0, ext=1, check_finite=False) + def i_interp1d(r, f): - return interp.interp1d(r, f, kind='cubic', fill_value=(f[0], 0.), bounds_error=False) + return interp.interp1d( + r, f, kind="cubic", fill_value=(f[0], 0.0), bounds_error=False + ) + def i_spline(r, f): from functools import partial + tck = interp.splrep(r, f, k=5, s=0) return partial(interp.splev, tck=tck, der=0, ext=1) @@ -212,7 +220,7 @@ def test_same1(self): o0 = SphericalOrbital(0, rf) o1 = Orbital(o0.R) assert o0.equal(o1) - assert not o0.equal(Orbital(3.)) + assert not o0.equal(Orbital(3.0)) def test_toatomicorbital1(self): rf = r_f(6) @@ -220,7 +228,7 @@ def test_toatomicorbital1(self): for l in range(_max_l + 1): orb = SphericalOrbital(l, rf) ao = orb.toAtomicOrbital() - assert len(ao) == 2*l + 1 + assert len(ao) == 2 * l + 1 m = -l for a in ao: assert a.l == orb.l @@ -252,17 +260,18 @@ def test_toatomicorbital2(self): def test_toatomicorbital_q0(self): rf = r_f(6) - orb = SphericalOrbital(0, rf, 2.) + orb = SphericalOrbital(0, rf, 2.0) # Check m and l for l in range(_max_l + 1): - orb = SphericalOrbital(l, rf, 2.) + orb = SphericalOrbital(l, rf, 2.0) ao = orb.toAtomicOrbital() - assert ao[0].q0 == pytest.approx(2. / (2*l+1)) + assert ao[0].q0 == pytest.approx(2.0 / (2 * l + 1)) def test_pickle1(self): rf = r_f(6) import pickle as p + o0 = SphericalOrbital(1, rf) o1 = SphericalOrbital(2, rf) p0 = p.dumps(o0) @@ -286,23 +295,24 @@ def test_togrid2(self): class Test_atomicorbital: - def test_init1(self): rf = r_f(6) a = [] a.append(AtomicOrbital(2, 1, 0, 1, True, rf)) a.append(AtomicOrbital(l=1, m=0, zeta=1, P=True, spherical=rf)) - f = interp.interp1d(rf[0], rf[1], fill_value=(0., 0.), bounds_error=False, kind='cubic') + f = interp.interp1d( + rf[0], rf[1], fill_value=(0.0, 0.0), bounds_error=False, kind="cubic" + ) a.append(AtomicOrbital(l=1, m=0, zeta=1, P=True, spherical=f)) - a.append(AtomicOrbital('pzP', f)) - a.append(AtomicOrbital('pzP', rf)) - a.append(AtomicOrbital('2pzP', rf)) + a.append(AtomicOrbital("pzP", f)) + a.append(AtomicOrbital("pzP", rf)) + a.append(AtomicOrbital("2pzP", rf)) for i in range(len(a) - 1): - for j in range(i+1, len(a)): + for j in range(i + 1, len(a)): assert a[i] == a[j] and a[i].equal(a[j], psi=True, radial=True) def test_init2(self): - assert AtomicOrbital('pzP') == AtomicOrbital(n=2, l=1, m=0, P=True) + assert AtomicOrbital("pzP") == AtomicOrbital(n=2, l=1, m=0, P=True) def test_init3(self): rf = r_f(6) @@ -311,7 +321,7 @@ def test_init3(self): a.name() a.name(True) str(a) - a = AtomicOrbital(l=l, m=0, P=True, spherical=rf, tag='hello') + a = AtomicOrbital(l=l, m=0, P=True, spherical=rf, tag="hello") a.name() a.name(True) str(a) @@ -319,10 +329,10 @@ def test_init3(self): def test_init4(self): rf = r_f(6) o1 = AtomicOrbital(2, 1, 0, 1, True, rf) - o2 = AtomicOrbital('pzP', rf) - o3 = AtomicOrbital('pzZP', rf) - o4 = AtomicOrbital('pzZ1P', rf) - o5 = AtomicOrbital('2pzZ1P', rf) + o2 = AtomicOrbital("pzP", rf) + o3 = AtomicOrbital("pzZP", rf) + o4 = AtomicOrbital("pzZ1P", rf) + o5 = AtomicOrbital("2pzZ1P", rf) assert o1 == o2 assert o1 == o3 assert o1 == o4 @@ -334,10 +344,10 @@ def test_init5(self): def test_copy(self): rf = r_f(6) - orb = AtomicOrbital('pzP', rf, R=2.) + orb = AtomicOrbital("pzP", rf, R=2.0) assert orb.R == orb.copy().R - assert orb.R == pytest.approx(2.) - orb = AtomicOrbital('pzP', rf) + assert orb.R == pytest.approx(2.0) + orb = AtomicOrbital("pzP", rf) assert orb.R == orb.copy().R def test_radial1(self): @@ -346,7 +356,7 @@ def test_radial1(self): for l in range(_max_l + 1): so = SphericalOrbital(l, rf) sor = so.radial(r) - for m in range(-l, l+1): + for m in range(-l, l + 1): o = AtomicOrbital(l=l, m=m, spherical=rf) assert np.allclose(sor, o.radial(r)) o.set_radial(rf[0], rf[1]) @@ -357,14 +367,15 @@ def test_phi1(self): r = np.linspace(0, 6, 999).reshape(-1, 3) for l in range(_max_l + 1): so = SphericalOrbital(l, rf) - for m in range(-l, l+1): + for m in range(-l, l + 1): o = AtomicOrbital(l=l, m=m, spherical=rf) assert np.allclose(so.psi(r, m), o.psi(r)) def test_pickle1(self): import pickle as p + rf = r_f(6) - o0 = AtomicOrbital(2, 1, 0, 1, True, rf, tag='hello', q0=1.) + o0 = AtomicOrbital(2, 1, 0, 1, True, rf, tag="hello", q0=1.0) o1 = AtomicOrbital(l=1, m=0, zeta=1, P=False, spherical=rf) o2 = AtomicOrbital(l=1, m=0, zeta=1, P=False) p0 = p.dumps(o0) @@ -382,19 +393,18 @@ def test_pickle1(self): class Test_hydrogenicorbital: - def test_init(self): orb = HydrogenicOrbital(2, 1, 0, 3.2) def test_basic1(self): - orb = HydrogenicOrbital(2, 1, 0, 3.2, R=4.) + orb = HydrogenicOrbital(2, 1, 0, 3.2, R=4.0) assert orb.R == orb.copy().R - assert orb.R == pytest.approx(4.) + assert orb.R == pytest.approx(4.0) orb = HydrogenicOrbital(2, 1, 0, 3.2) assert orb.R == orb.copy().R def test_copy(self): - orb = HydrogenicOrbital(2, 1, 0, 3.2, tag='test', q0=2.5) + orb = HydrogenicOrbital(2, 1, 0, 3.2, tag="test", q0=2.5) orb2 = orb.copy() assert orb.n == orb2.n assert orb.l == orb2.l @@ -409,7 +419,7 @@ def test_normalization(self): orb = HydrogenicOrbital(n, l, 0, zeff) x = np.linspace(0, orb.R, 1000, endpoint=True) Rnl = orb.radial(x) - I = np.trapz(x ** 2 * Rnl ** 2, x=x) + I = np.trapz(x**2 * Rnl**2, x=x) assert abs(I - 1) < 1e-4 def test_togrid(self): @@ -419,12 +429,13 @@ def test_togrid(self): for m in range(-l, l + 1): orb = HydrogenicOrbital(n, l, m, zeff) g = orb.toGrid(0.1) - I = (g.grid ** 2).sum() * g.dvolume + I = (g.grid**2).sum() * g.dvolume assert abs(I - 1) < 1e-3 def test_pickle(self): import pickle as p - o0 = HydrogenicOrbital(2, 1, 0, 3.2, tag='test', q0=2.5) + + o0 = HydrogenicOrbital(2, 1, 0, 3.2, tag="test", q0=2.5) o1 = HydrogenicOrbital(2, 1, 0, 3.2) p0 = p.dumps(o0) p1 = p.dumps(o1) @@ -437,7 +448,6 @@ def test_pickle(self): class Test_GTO: - def test_init(self): alpha = [1, 2] coeff = [0.1, 0.44] @@ -447,9 +457,9 @@ def test_init(self): def test_copy(self): alpha = [1, 2] coeff = [0.1, 0.44] - orb = GTOrbital(2, 1, 0, alpha, coeff, R=4.) + orb = GTOrbital(2, 1, 0, alpha, coeff, R=4.0) assert orb.R == orb.copy().R - assert orb.R == pytest.approx(4.) + assert orb.R == pytest.approx(4.0) orb = GTOrbital(2, 1, 0, alpha, coeff) assert orb.R == orb.copy().R @@ -471,7 +481,6 @@ def test_gto_funcs(self): class Test_STO: - def test_init(self): alpha = [1, 2] coeff = [0.1, 0.44] @@ -481,9 +490,9 @@ def test_init(self): def test_copy(self): alpha = [1, 2] coeff = [0.1, 0.44] - orb = STOrbital(2, 1, 0, alpha, coeff, R=4.) + orb = STOrbital(2, 1, 0, alpha, coeff, R=4.0) assert orb.R == orb.copy().R - assert orb.R == pytest.approx(4.) + assert orb.R == pytest.approx(4.0) orb = STOrbital(2, 1, 0, alpha, coeff) assert orb.R == orb.copy().R diff --git a/src/sisl/tests/test_plot.py b/src/sisl/tests/test_plot.py index c42fcad55e..eae55fdb7a 100644 --- a/src/sisl/tests/test_plot.py +++ b/src/sisl/tests/test_plot.py @@ -8,9 +8,9 @@ pytestmark = pytest.mark.plot -mlib = pytest.importorskip('matplotlib') -plt = pytest.importorskip('matplotlib.pyplot') -mlib3d = pytest.importorskip('mpl_toolkits.mplot3d') +mlib = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") +mlib3d = pytest.importorskip("mpl_toolkits.mplot3d") def test_supercell_2d(): @@ -18,17 +18,17 @@ def test_supercell_2d(): sisl.plot(g.lattice, axis=[0, 1]) sisl.plot(g.lattice, axis=[0, 2]) sisl.plot(g.lattice, axis=[1, 2]) - plt.close('all') + plt.close("all") ax = plt.subplot(111) sisl.plot(g.lattice, axis=[1, 2], axes=ax) - plt.close('all') + plt.close("all") def test_supercell_3d(): g = sisl.geom.graphene() sisl.plot(g.lattice) - plt.close('all') + plt.close("all") def test_geometry_2d(): @@ -36,11 +36,11 @@ def test_geometry_2d(): sisl.plot(g, axis=[0, 1]) sisl.plot(g, axis=[0, 2]) sisl.plot(g, axis=[1, 2]) - plt.close('all') + plt.close("all") ax = plt.subplot(111) sisl.plot(g, axis=[1, 2], axes=ax) - plt.close('all') + plt.close("all") def test_geometry_2d_atom_indices(): @@ -48,48 +48,49 @@ def test_geometry_2d_atom_indices(): sisl.plot(g, axis=[0, 1]) sisl.plot(g, axis=[0, 2]) sisl.plot(g, axis=[1, 2]) - plt.close('all') + plt.close("all") ax = plt.subplot(111) sisl.plot(g, axis=[1, 2], axes=ax, atom_indices=True) - plt.close('all') + plt.close("all") def test_geometry_3d(): g = sisl.geom.graphene() sisl.plot(g) - plt.close('all') + plt.close("all") def test_geometry_3d_atom_indices(): g = sisl.geom.graphene() sisl.plot(g, atom_indices=True) - plt.close('all') + plt.close("all") def test_orbital_radial(): r = np.linspace(0, 10, 1000) - f = np.exp(- r) + f = np.exp(-r) o = sisl.SphericalOrbital(2, (r, f)) sisl.plot(o) - plt.close('all') + plt.close("all") fig = plt.figure() sisl.plot(o, axes=fig.gca()) - plt.close('all') + plt.close("all") def test_orbital_harmonics(): r = np.linspace(0, 10, 1000) - f = np.exp(- r) + f = np.exp(-r) o = sisl.SphericalOrbital(2, (r, f)) sisl.plot(o, harmonics=True) - plt.close('all') + plt.close("all") def test_not_implemented(): class Test: pass + t = Test() with pytest.raises(NotImplementedError): sisl.plot(t) diff --git a/src/sisl/tests/test_quaternion.py b/src/sisl/tests/test_quaternion.py index fa384b7de4..27b96772fa 100644 --- a/src/sisl/tests/test_quaternion.py +++ b/src/sisl/tests/test_quaternion.py @@ -11,7 +11,7 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): self.qx = Quaternion(90, [1, 0, 0]) self.qy = Quaternion(90, [0, 1, 0]) @@ -19,12 +19,12 @@ def __init__(self): self.Qx = Quaternion(90, [2, 0, 0]) self.Qy = Quaternion(90, [0, 2, 0]) self.Qz = Quaternion(90, [0, 0, 2]) + return t() @pytest.mark.quaternion class TestQuaternion: - def test_copy(self, setup): qx = setup.qx.copy() assert qx == setup.qx @@ -34,12 +34,12 @@ def test_conj(self, setup): assert qx.conj() == setup.qx def test_norm(self, setup): - for c in 'xyz': - assert getattr(setup, 'q'+c).norm() == 1. + for c in "xyz": + assert getattr(setup, "q" + c).norm() == 1.0 def test_degree1(self, setup): - for c in 'xyz': - assert getattr(setup, 'q'+c).degree == 90 + for c in "xyz": + assert getattr(setup, "q" + c).degree == 90 def test_radians1(self, setup): rx = setup.qx.radian @@ -63,21 +63,21 @@ def test_op1(self, setup): assert -rx == setup.qx def test_op2(self, setup): - rx = setup.qx + 1. - assert rx - 1. == setup.qx + rx = setup.qx + 1.0 + assert rx - 1.0 == setup.qx - rx = setup.qx * 1. + rx = setup.qx * 1.0 assert rx == setup.qx - rx = setup.qx * 1. + rx = setup.qx * 1.0 assert rx == setup.qx - rx = setup.qx / 1. + rx = setup.qx / 1.0 assert rx == setup.qx rx = setup.qx.copy() - rx += 1. - rx -= 1. + rx += 1.0 + rx -= 1.0 assert rx == setup.qx rx = setup.qx.copy() diff --git a/src/sisl/tests/test_selector.py b/src/sisl/tests/test_selector.py index 009ca84ebe..3e9fd4ef57 100644 --- a/src/sisl/tests/test_selector.py +++ b/src/sisl/tests/test_selector.py @@ -11,20 +11,22 @@ @pytest.mark.selector -@pytest.mark.xfail(sys.platform.startswith("darwin"), - reason="Sleep on MacOS is not consistent causing erroneous fails.") +@pytest.mark.xfail( + sys.platform.startswith("darwin"), + reason="Sleep on MacOS is not consistent causing erroneous fails.", +) class TestSelector: - def sleep(self, *args): if len(args) == 1: + def _sleep(): - time.sleep(1. / 100 * args[0]) + time.sleep(1.0 / 100 * args[0]) + _sleep.__name__ = str(args[0]) return _sleep return [self.sleep(arg) for arg in args] def test_selector1(self): - sel = TimeSelector() sel.prepend(self.sleep(1)) sel.prepend(self.sleep(2)) diff --git a/src/sisl/tests/test_sgeom.py b/src/sisl/tests/test_sgeom.py index 8f718f4d9c..3a2870a4ea 100644 --- a/src/sisl/tests/test_sgeom.py +++ b/src/sisl/tests/test_sgeom.py @@ -11,141 +11,147 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) C = Atom(Z=6, R=[bond * 1.01] * 2) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.mol = Geometry([[i, 0, 0] for i in range(10)], lattice=[50]) def sg_g(**kwargs): - kwargs['ret_geometry'] = True - if 'geometry' not in kwargs: - kwargs['geometry'] = self.g.copy() + kwargs["ret_geometry"] = True + if "geometry" not in kwargs: + kwargs["geometry"] = self.g.copy() return sgeom(**kwargs) self.sg_g = sg_g def sg_mol(**kwargs): - kwargs['ret_geometry'] = True - if 'geometry' not in kwargs: - kwargs['geometry'] = self.mol.copy() + kwargs["ret_geometry"] = True + if "geometry" not in kwargs: + kwargs["geometry"] = self.mol.copy() return sgeom(**kwargs) self.sg_mol = sg_mol + return t() @pytest.mark.geometry class TestGeometry: - def test_help(self): with pytest.raises(SystemExit): - sgeom(argv=['--help']) + sgeom(argv=["--help"]) def test_version(self): - sgeom(argv=['--version']) + sgeom(argv=["--version"]) def test_cite(self): - sgeom(argv=['--cite']) + sgeom(argv=["--cite"]) def test_tile1(self, setup): cell = np.copy(setup.g.lattice.cell) cell[0, :] *= 2 - for tile in ['tile 2 x', 'tile-x 2']: - tx = setup.sg_g(argv=('--' + tile).split()) + for tile in ["tile 2 x", "tile-x 2"]: + tx = setup.sg_g(argv=("--" + tile).split()) assert np.allclose(cell, tx.lattice.cell) cell[1, :] *= 2 - for tile in ['tile 2 y', 'tile-y 2']: - ty = setup.sg_g(geometry=tx, argv=('--' + tile).split()) + for tile in ["tile 2 y", "tile-y 2"]: + ty = setup.sg_g(geometry=tx, argv=("--" + tile).split()) assert np.allclose(cell, ty.lattice.cell) cell[2, :] *= 2 - for tile in ['tile 2 z', 'tile-z 2']: - tz = setup.sg_g(geometry=ty, argv=('--' + tile).split()) + for tile in ["tile 2 z", "tile-z 2"]: + tz = setup.sg_g(geometry=ty, argv=("--" + tile).split()) assert np.allclose(cell, tz.lattice.cell) def test_tile2(self, setup): cell = np.copy(setup.g.lattice.cell) cell[:, :] *= 2 - for xt in ['tile 2 x', 'tile-x 2']: - xt = '--' + xt - for yt in ['tile 2 y', 'tile-y 2']: - yt = '--' + yt - for zt in ['tile 2 z', 'tile-z 2']: - zt = '--' + zt - argv = ' '.join([xt, yt, zt]).split() + for xt in ["tile 2 x", "tile-x 2"]: + xt = "--" + xt + for yt in ["tile 2 y", "tile-y 2"]: + yt = "--" + yt + for zt in ["tile 2 z", "tile-z 2"]: + zt = "--" + zt + argv = " ".join([xt, yt, zt]).split() t = setup.sg_g(argv=argv) assert np.allclose(cell, t.lattice.cell) def test_repeat1(self, setup): cell = np.copy(setup.g.lattice.cell) cell[0, :] *= 2 - for repeat in ['repeat 2 x', 'repeat-x 2']: - tx = setup.sg_g(argv=('--' + repeat).split()) + for repeat in ["repeat 2 x", "repeat-x 2"]: + tx = setup.sg_g(argv=("--" + repeat).split()) assert np.allclose(cell, tx.lattice.cell) cell[1, :] *= 2 - for repeat in ['repeat 2 y', 'repeat-y 2']: - ty = setup.sg_g(geometry=tx, argv=('--' + repeat).split()) + for repeat in ["repeat 2 y", "repeat-y 2"]: + ty = setup.sg_g(geometry=tx, argv=("--" + repeat).split()) assert np.allclose(cell, ty.lattice.cell) cell[2, :] *= 2 - for repeat in ['repeat 2 z', 'repeat-z 2']: - tz = setup.sg_g(geometry=ty, argv=('--' + repeat).split()) + for repeat in ["repeat 2 z", "repeat-z 2"]: + tz = setup.sg_g(geometry=ty, argv=("--" + repeat).split()) assert np.allclose(cell, tz.lattice.cell) def test_repeat2(self, setup): cell = np.copy(setup.g.lattice.cell) cell[:, :] *= 2 - for xt in ['repeat 2 x', 'repeat-x 2']: - xt = '--' + xt - for yt in ['repeat 2 y', 'repeat-y 2']: - yt = '--' + yt - for zt in ['repeat 2 z', 'repeat-z 2']: - zt = '--' + zt - argv = ' '.join([xt, yt, zt]).split() + for xt in ["repeat 2 x", "repeat-x 2"]: + xt = "--" + xt + for yt in ["repeat 2 y", "repeat-y 2"]: + yt = "--" + yt + for zt in ["repeat 2 z", "repeat-z 2"]: + zt = "--" + zt + argv = " ".join([xt, yt, zt]).split() t = setup.sg_g(argv=argv) assert np.allclose(cell, t.lattice.cell) def test_sub(self, setup): - for a, l in [('0', 1), ('0,1', 2), ('0-1', 2)]: - g = setup.sg_g(argv=['--sub', a]) + for a, l in [("0", 1), ("0,1", 2), ("0-1", 2)]: + g = setup.sg_g(argv=["--sub", a]) assert len(g) == l def test_remove(self, setup): geom = setup.g.tile(2, 0).tile(2, 1) N = len(geom) - for a, l in [('0', 1), ('0,1', 2), ('0-1', 2)]: - g = setup.sg_g(geometry=geom.copy(), argv=['--remove', a]) + for a, l in [("0", 1), ("0,1", 2), ("0-1", 2)]: + g = setup.sg_g(geometry=geom.copy(), argv=["--remove", a]) assert len(g) == N - l def test_rotation1(self, setup): - rot = setup.sg_g(argv='--rotate 180 z'.split()) + rot = setup.sg_g(argv="--rotate 180 z".split()) rot.lattice.cell[2, 2] *= -1 assert np.allclose(-rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(-rot.xyz, setup.g.xyz) - rot = setup.sg_g(argv='--rotate-z 180'.split()) + rot = setup.sg_g(argv="--rotate-z 180".split()) rot.lattice.cell[2, 2] *= -1 assert np.allclose(-rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(-rot.xyz, setup.g.xyz) - rot = setup.sg_g(argv='--rotate rpi z'.split()) + rot = setup.sg_g(argv="--rotate rpi z".split()) rot.lattice.cell[2, 2] *= -1 assert np.allclose(-rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(-rot.xyz, setup.g.xyz) - rot = setup.sg_g(argv='--rotate-z rpi'.split()) + rot = setup.sg_g(argv="--rotate-z rpi".split()) rot.lattice.cell[2, 2] *= -1 assert np.allclose(-rot.lattice.cell, setup.g.lattice.cell) assert np.allclose(-rot.xyz, setup.g.xyz) def test_swap(self, setup): - s = setup.sg_g(argv='--swap 0 1'.split()) + s = setup.sg_g(argv="--swap 0 1".split()) for i in [0, 1, 2]: assert np.allclose(setup.g.xyz[::-1, i], s.xyz[:, i]) diff --git a/src/sisl/tests/test_sgrid.py b/src/sisl/tests/test_sgrid.py index 3041682d96..7e494df92c 100644 --- a/src/sisl/tests/test_sgrid.py +++ b/src/sisl/tests/test_sgrid.py @@ -10,22 +10,28 @@ from sisl import Atom, Geometry, Grid, Lattice, get_sile from sisl.grid import sgrid -_dir = osp.join('sisl') +_dir = osp.join("sisl") @pytest.fixture def setup(): - class t(): + class t: def __init__(self): bond = 1.42 - sq3h = 3.**.5 * 0.5 - self.lattice = Lattice(np.array([[1.5, sq3h, 0.], - [1.5, -sq3h, 0.], - [0., 0., 10.]], np.float64) * bond, nsc=[3, 3, 1]) + sq3h = 3.0**0.5 * 0.5 + self.lattice = Lattice( + np.array( + [[1.5, sq3h, 0.0], [1.5, -sq3h, 0.0], [0.0, 0.0, 10.0]], np.float64 + ) + * bond, + nsc=[3, 3, 1], + ) C = Atom(Z=6, R=[bond * 1.01] * 2) - self.g = Geometry(np.array([[0., 0., 0.], - [1., 0., 0.]], np.float64) * bond, - atoms=C, lattice=self.lattice) + self.g = Geometry( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], np.float64) * bond, + atoms=C, + lattice=self.lattice, + ) self.grid = Grid(0.2, geometry=self.g) self.grid.grid[:, :, :] = np.random.rand(*self.grid.shape) @@ -35,160 +41,160 @@ def __init__(self): self.grid_mol.grid[:, :, :] = np.random.rand(*self.grid_mol.shape) def sg_g(**kwargs): - kwargs['ret_grid'] = True - if 'grid' not in kwargs: - kwargs['grid'] = self.grid + kwargs["ret_grid"] = True + if "grid" not in kwargs: + kwargs["grid"] = self.grid return sgrid(**kwargs) self.sg_g = sg_g def sg_mol(**kwargs): - kwargs['ret_grid'] = True - if 'grid' not in kwargs: - kwargs['grid'] = self.grid_mol + kwargs["ret_grid"] = True + if "grid" not in kwargs: + kwargs["grid"] = self.grid_mol return sgrid(**kwargs) self.sg_mol = sg_mol + return t() @pytest.mark.sgrid class TestsGrid: - def test_help(self): with pytest.raises(SystemExit): - sgrid(argv=['--help']) + sgrid(argv=["--help"]) def test_version(self): - sgrid(argv=['--version']) + sgrid(argv=["--version"]) def test_cite(self): - sgrid(argv=['--cite']) + sgrid(argv=["--cite"]) def test_average1(self, setup): g = setup.grid.copy() gavg = g.average(0) - for avg in ['average x', 'average 0']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["average x", "average 0"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) gavg = g.average(1) - for avg in ['average y', 'average 1']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["average y", "average 1"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) gavg = g.average(2) - for avg in ['average z', 'average 2']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["average z", "average 2"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) def test_average2(self, setup): g = setup.grid.copy() gavg = g.average(0).average(1) - for avg in ['average x --average 1', 'average 0 --average y']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["average x --average 1", "average 0 --average y"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) def test_sum1(self, setup): g = setup.grid.copy() gavg = g.sum(0) - for avg in ['sum x', 'sum 0']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["sum x", "sum 0"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) gavg = g.sum(1) - for avg in ['sum y', 'sum 1']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["sum y", "sum 1"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) gavg = g.sum(2) - for avg in ['sum z', 'sum 2']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["sum z", "sum 2"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) def test_sum2(self, setup): g = setup.grid.copy() gavg = g.sum(0).sum(1) - for avg in ['sum x --sum 1', 'sum 0 --sum y']: - G = setup.sg_g(argv=('--' + avg).split()) + for avg in ["sum x --sum 1", "sum 0 --sum y"]: + G = setup.sg_g(argv=("--" + avg).split()) assert np.allclose(G.grid, gavg.grid) def test_print1(self, setup): - setup.sg_g(argv=['--info']) + setup.sg_g(argv=["--info"]) def test_interp(self, setup): - g1 = setup.sg_g(argv='--interp 10 10 10'.split()) + g1 = setup.sg_g(argv="--interp 10 10 10".split()) # last argument is default - g2 = setup.sg_g(argv='--interp 10 10 10 1'.split()) + g2 = setup.sg_g(argv="--interp 10 10 10 1".split()) assert np.allclose(g1.grid, g2.grid) - g2 = setup.sg_g(argv='--interp 10 10 10 3'.split()) + g2 = setup.sg_g(argv="--interp 10 10 10 3".split()) assert not np.allclose(g1.grid, g2.grid) - g3 = setup.sg_g(argv='--interp 0.01 0.1 1. 3'.split()) + g3 = setup.sg_g(argv="--interp 0.01 0.1 1. 3".split()) def test_smooth(self, setup): - g1 = setup.sg_g(argv=['--smooth']) - g2 = setup.sg_g(argv='--smooth 0.7'.split()) + g1 = setup.sg_g(argv=["--smooth"]) + g2 = setup.sg_g(argv="--smooth 0.7".split()) assert np.allclose(g1.grid, g2.grid) - g2 = setup.sg_g(argv='--smooth 1.'.split()) + g2 = setup.sg_g(argv="--smooth 1.".split()) assert not np.allclose(g1.grid, g2.grid) def test_sub1(self, setup): g = setup.grid.copy() - idx = g.index(1., 0) + idx = g.index(1.0, 0) gs = g.sub_part(idx, 0, True) - for sub in ['sub 1.: a', 'sub 1.: 0']: - G = setup.sg_g(argv=('--' + sub).split()) + for sub in ["sub 1.: a", "sub 1.: 0"]: + G = setup.sg_g(argv=("--" + sub).split()) assert np.allclose(G.grid, gs.grid) idx = g.index(1.2, 1) gs = g.sub_part(idx, 1, True) - for sub in ['sub 1.2: b', 'sub 1.2: y']: - G = setup.sg_g(argv=('--' + sub).split()) + for sub in ["sub 1.2: b", "sub 1.2: y"]: + G = setup.sg_g(argv=("--" + sub).split()) assert np.allclose(G.grid, gs.grid) idx = g.index(0.8, 2) gs = g.sub_part(idx, 2, False) - for sub in ['sub :.8 2', 'sub :0.8 z']: - G = setup.sg_g(argv=('--' + sub).split()) + for sub in ["sub :.8 2", "sub :0.8 z"]: + G = setup.sg_g(argv=("--" + sub).split()) assert np.allclose(G.grid, gs.grid) def test_remove1(self, setup): g = setup.grid.copy() - idx = g.index(1., 0) + idx = g.index(1.0, 0) gs = g.remove_part(idx, 0, True) - for remove in ['remove 1.: a', 'remove 1.: 0']: - G = setup.sg_g(argv=('--' + remove).split()) + for remove in ["remove 1.: a", "remove 1.: 0"]: + G = setup.sg_g(argv=("--" + remove).split()) assert np.allclose(G.grid, gs.grid) idx = g.index(1.2, 1) gs = g.remove_part(idx, 1, False) - for remove in ['remove :1.2 b', 'remove :1.2 y']: - G = setup.sg_g(argv=('--' + remove).split()) + for remove in ["remove :1.2 b", "remove :1.2 y"]: + G = setup.sg_g(argv=("--" + remove).split()) assert np.allclose(G.grid, gs.grid) idx = g.index(0.8, 2) gs = g.remove_part(idx, 2, False) - for remove in ['remove :.8 2', 'remove :0.8 z']: - G = setup.sg_g(argv=('--' + remove).split()) + for remove in ["remove :.8 2", "remove :0.8 z"]: + G = setup.sg_g(argv=("--" + remove).split()) assert np.allclose(G.grid, gs.grid) def test_tile1(self, setup): g = setup.grid.copy() g2 = g.tile(2, 0) - G = setup.sg_g(argv='--tile 2 a'.split()) + G = setup.sg_g(argv="--tile 2 a".split()) assert np.allclose(G.grid, g2.grid) g2 = g.tile(2, 1) - G = setup.sg_g(argv='--tile 2 y'.split()) + G = setup.sg_g(argv="--tile 2 y".split()) assert np.allclose(G.grid, g2.grid) g2 = g.tile(2, 2) - G = setup.sg_g(argv='--tile 2 c'.split()) + G = setup.sg_g(argv="--tile 2 c".split()) assert np.allclose(G.grid, g2.grid) def test_write_data(self, setup, sisl_tmp): - out = sisl_tmp('table.dat', _dir) - G = setup.sg_g(argv=f'--sum 0 --average 1 --out {out}'.split()) + out = sisl_tmp("table.dat", _dir) + G = setup.sg_g(argv=f"--sum 0 --average 1 --out {out}".split()) dat = get_sile(out).read_data() assert np.allclose(dat[1, :], G.grid.ravel()) def test_write_grid(self, setup, sisl_tmp): - out = sisl_tmp('table.cube', _dir) - G = setup.sg_g(argv=f'--sum 0 --out {out}'.split()) + out = sisl_tmp("table.cube", _dir) + G = setup.sg_g(argv=f"--sum 0 --out {out}".split()) diff --git a/src/sisl/tests/test_sparse.py b/src/sisl/tests/test_sparse.py index d4e64e7498..40a933f1df 100644 --- a/src/sisl/tests/test_sparse.py +++ b/src/sisl/tests/test_sparse.py @@ -12,16 +12,20 @@ from sisl.sparse import * from sisl.sparse import indices -pytestmark = [pytest.mark.sparse, pytest.mark.filterwarnings("ignore", category=sc.sparse.SparseEfficiencyWarning)] +pytestmark = [ + pytest.mark.sparse, + pytest.mark.filterwarnings("ignore", category=sc.sparse.SparseEfficiencyWarning), +] @pytest.fixture def setup(): - class t(): + class t: def __init__(self): self.s1 = SparseCSR((10, 100), dtype=np.int32) self.s1d = SparseCSR((10, 100)) self.s2 = SparseCSR((10, 100, 2)) + return t() @@ -231,9 +235,9 @@ def test_create2(setup): s1 = setup.s1 assert len(s1) == s1.shape[0] for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) s1[0, j] = i - assert s1.nnz == (i+1)*3 + assert s1.nnz == (i + 1) * 3 for jj in j: assert s1[0, jj] == i assert s1[1, jj] == 0 @@ -243,11 +247,11 @@ def test_create2(setup): def test_create3(setup): s1 = setup.s1 for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) s1[0, j] = i - assert s1.nnz == (i+1)*3 - s1[0, range((i+1)*4, (i+1)*4+3)] = None - assert s1.nnz == (i+1)*3 + assert s1.nnz == (i + 1) * 3 + s1[0, range((i + 1) * 4, (i + 1) * 4 + 3)] = None + assert s1.nnz == (i + 1) * 3 for jj in j: assert s1[0, jj] == i assert s1[1, jj] == 0 @@ -257,7 +261,7 @@ def test_create3(setup): def test_create_1d_bcasting_data_1d(setup): s1 = setup.s1.copy() for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) s1[0, j] = i s1[1, j] = i s1[2, j] = i @@ -272,12 +276,14 @@ def test_create_1d_bcasting_data_1d(setup): assert np.sum(s1 - s2) == 0 -@pytest.mark.xfail(sys.platform.startswith("win"), reason="Unknown windows error in b-casting") +@pytest.mark.xfail( + sys.platform.startswith("win"), reason="Unknown windows error in b-casting" +) def test_create_1d_bcasting_data_2d(setup): s1 = setup.s1.copy() data = np.random.randint(1, 100, (4, 3)) for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) s1[0, j] = data[0, :] s1[1, j] = data[1, :] s1[2, j] = data[2, :] @@ -364,7 +370,7 @@ def test_create_2d_data_2d(setup): s1 = setup.s2.copy() # matrix assignment I = np.arange(len(s1) // 2) - data = np.random.randint(1, 100, I.size ** 2).reshape(I.size, I.size) + data = np.random.randint(1, 100, I.size**2).reshape(I.size, I.size) for i in I: s1[i, I, 0] = data[i] s1[i, I, 1] = data[i] @@ -387,7 +393,7 @@ def test_create_2d_data_3d(setup): s1 = setup.s2.copy() # matrix assignment I = np.arange(len(s1) // 2) - data = np.random.randint(1, 100, I.size ** 2 * 2).reshape(I.size, I.size, 2) + data = np.random.randint(1, 100, I.size**2 * 2).reshape(I.size, I.size, 2) for i in I: s1[i, I] = data[i] @@ -424,7 +430,7 @@ def test_fail_data_2d_to_2d(setup): s2 = setup.s2 # matrix assignment I = np.arange(len(s2) // 2).reshape(-1, 1) - data = np.random.randint(1, 100, I.size **2).reshape(I.size, I.size) + data = np.random.randint(1, 100, I.size**2).reshape(I.size, I.size) with pytest.raises(ValueError): s2[I, I.T] = data s2.empty() @@ -434,7 +440,7 @@ def test_fail_data_2d_to_3d(setup): s2 = setup.s2 # matrix assignment I = np.arange(len(s2) // 2).reshape(-1, 1) - data = np.random.randint(1, 100, I.size **2).reshape(I.size, I.size) + data = np.random.randint(1, 100, I.size**2).reshape(I.size, I.size) with pytest.raises(ValueError): s2[I, I.T, [0, 1]] = data s2.empty() @@ -443,17 +449,17 @@ def test_fail_data_2d_to_3d(setup): def test_finalize1(setup): s1 = setup.s1 s1[0, [1, 2, 3]] = 1 - s1[2, [1, 2, 3]] = 1. - s1[1, [3, 2, 1]] = 1. + s1[2, [1, 2, 3]] = 1.0 + s1[1, [3, 2, 1]] = 1.0 assert not s1.finalized p = s1.ptr.view() n = s1.ncol.view() # Assert that the ordering is good - assert np.allclose(s1.col[p[1]:p[1]+n[1]], [3, 2, 1]) + assert np.allclose(s1.col[p[1] : p[1] + n[1]], [3, 2, 1]) s1.finalize() # This also asserts that we do not change the memory-locations # of the pointers and ncol - assert np.allclose(s1.col[p[1]:p[1]+n[1]], [1, 2, 3]) + assert np.allclose(s1.col[p[1] : p[1] + n[1]], [1, 2, 3]) assert s1.finalized s1.empty(keep_nnz=True) assert s1.finalized @@ -464,17 +470,17 @@ def test_finalize1(setup): def test_finalize2(setup): s1 = setup.s1 s1[0, [1, 2, 3]] = 1 - s1[2, [1, 2, 3]] = 1. - s1[1, [3, 2, 1]] = 1. + s1[2, [1, 2, 3]] = 1.0 + s1[1, [3, 2, 1]] = 1.0 assert not s1.finalized p = s1.ptr.view() n = s1.ncol.view() # Assert that the ordering is good - assert np.allclose(s1.col[p[1]:p[1]+n[1]], [3, 2, 1]) + assert np.allclose(s1.col[p[1] : p[1] + n[1]], [3, 2, 1]) s1.finalize(False) # This also asserts that we do not change the memory-locations # of the pointers and ncol - assert np.allclose(s1.col[p[1]:p[1]+n[1]], [3, 2, 1]) + assert np.allclose(s1.col[p[1] : p[1] + n[1]], [3, 2, 1]) assert not s1.finalized assert len(s1.col) == 9 s1.empty() @@ -483,7 +489,7 @@ def test_finalize2(setup): def test_iterator1(setup): s1 = setup.s1 s1[0, [1, 2, 3]] = 1 - s1[2, [1, 2, 4]] = 1. + s1[2, [1, 2, 4]] = 1.0 e = [[1, 2, 3], [], [1, 2, 4]] for i, j in s1: assert j in e[i] @@ -498,14 +504,14 @@ def test_iterator1(setup): for i, j in ispmatrix(s1): assert j in e[i] - for i, j in ispmatrix(s1, map_col = lambda x: x): + for i, j in ispmatrix(s1, map_col=lambda x: x): assert j in e[i] - for i, j in ispmatrix(s1, map_row = lambda x: x): + for i, j in ispmatrix(s1, map_row=lambda x: x): assert j in e[i] for i, j, d in ispmatrixd(s1): assert j in e[i] - assert d == 1. + assert d == 1.0 s1.empty() @@ -514,20 +520,20 @@ def test_iterator2(setup): s1 = setup.s1 e = [[1, 2, 3], [], [1, 2, 4]] s1[0, [1, 2, 3]] = 1 - s1[2, [1, 2, 4]] = 1. + s1[2, [1, 2, 4]] = 1.0 a = s1.tocsr() - for func in ['csr', 'csc', 'coo', 'lil']: - a = getattr(a, 'to' + func)() + for func in ["csr", "csc", "coo", "lil"]: + a = getattr(a, "to" + func)() for r, c in ispmatrix(a): assert r in [0, 2] assert c in e[r] - for func in ['csr', 'csc', 'coo', 'lil']: - a = getattr(a, 'to' + func)() + for func in ["csr", "csc", "coo", "lil"]: + a = getattr(a, "to" + func)() for r, c, d in ispmatrixd(a): assert r in [0, 2] assert c in e[r] - assert d == 1. + assert d == 1.0 s1.empty() @@ -536,20 +542,20 @@ def test_iterator3(setup): s1 = setup.s1 e = [[1, 2, 3], [], [1, 2, 4]] s1[0, [1, 2, 3]] = 1 - s1[2, [1, 2, 4]] = 1. + s1[2, [1, 2, 4]] = 1.0 a = s1.tocsr() - for func in ['csr', 'csc', 'coo', 'lil']: - a = getattr(a, 'to' + func)() + for func in ["csr", "csc", "coo", "lil"]: + a = getattr(a, "to" + func)() for r, c in ispmatrix(a): assert r in [0, 2] assert c in e[r] # number of mapped values nvals = 2 - for func in ['csr', 'lil']: - a = getattr(a, 'to' + func)() + for func in ["csr", "lil"]: + a = getattr(a, "to" + func)() n = 0 - for r, c in ispmatrix(a, lambda x: x%2, lambda x: x%2): + for r, c in ispmatrix(a, lambda x: x % 2, lambda x: x % 2): assert r == 0 assert c in [0, 1] n += 1 @@ -722,7 +728,7 @@ def test_delete_col3(setup): s2[i, 2] = 1 s1.finalize() s2.finalize() - assert s1.nnz == 10*3 + assert s1.nnz == 10 * 3 s1.delete_columns([3, 1], keep_shape=True) assert s1.ptr[-1] == s1.nnz assert s2.ptr[-1] == s2.nnz @@ -739,7 +745,7 @@ def test_delete_col4(): s2[i, 1] = 1 s1.finalize() s2.finalize() - assert s1.nnz == 10*3 + assert s1.nnz == 10 * 3 s1.delete_columns([3, 1]) assert s1.ptr[-1] == s1.nnz assert s2.ptr[-1] == s2.nnz @@ -887,13 +893,13 @@ def test_nonzero1(setup): def test_op1(setup): s1 = setup.s1 for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) s1[0, j] = i # i+ s1 += 1 for jj in j: - assert s1[0, jj] == i+1 + assert s1[0, jj] == i + 1 assert s1[1, jj] == 0 # i- @@ -905,7 +911,7 @@ def test_op1(setup): # i* s1 *= 2 for jj in j: - assert s1[0, jj] == i*2 + assert s1[0, jj] == i * 2 assert s1[1, jj] == 0 # // @@ -925,20 +931,20 @@ def test_op1(setup): def test_op2(setup): s1 = setup.s1 for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) s1[0, j] = i # + s = s1 + 1 for jj in j: - assert s[0, jj] == i+1 + assert s[0, jj] == i + 1 assert s1[0, jj] == i assert s[1, jj] == 0 # - s = s1 - 1 for jj in j: - assert s[0, jj] == i-1 + assert s[0, jj] == i - 1 assert s1[0, jj] == i assert s[1, jj] == 0 @@ -952,21 +958,21 @@ def test_op2(setup): # * s = s1 * 2 for jj in j: - assert s[0, jj] == i*2 + assert s[0, jj] == i * 2 assert s1[0, jj] == i assert s[1, jj] == 0 # * s = np.multiply(s1, 2) for jj in j: - assert s[0, jj] == i*2 + assert s[0, jj] == i * 2 assert s1[0, jj] == i assert s[1, jj] == 0 # * s.empty() np.multiply(s1, 2, out=s) for jj in j: - assert s[0, jj] == i*2 + assert s[0, jj] == i * 2 assert s1[0, jj] == i assert s[1, jj] == 0 @@ -978,14 +984,14 @@ def test_op2(setup): assert s[1, jj] == 0 # ** - s = s1 ** 2 + s = s1**2 for jj in j: assert s[0, jj] == i**2 assert s1[0, jj] == i assert s[1, jj] == 0 # ** (r) - s = 2 ** s1 + s = 2**s1 for jj in j: assert s[0, jj], 2 ** s1[0 == jj] assert s1[0, jj] == i @@ -1005,7 +1011,8 @@ def test_op_csr(setup): # + s = s1 + csr for jj in j: - if jj == 0: continue + if jj == 0: + continue assert s[0, jj] == i assert s1[0, jj] == i assert s[1, jj] == 0 @@ -1014,7 +1021,8 @@ def test_op_csr(setup): # - s = s1 - csr for jj in j: - if jj == 0: continue + if jj == 0: + continue assert s[0, jj] == i assert s1[0, jj] == i assert s[1, jj] == 0 @@ -1023,7 +1031,8 @@ def test_op_csr(setup): # - (r) s = csr - s1 for jj in j: - if jj == 0: continue + if jj == 0: + continue assert s[0, jj] == -i assert s1[0, jj] == i assert s[1, jj] == 0 @@ -1034,7 +1043,8 @@ def test_op_csr(setup): # * s = s1 * csr for jj in j: - if jj == 0: continue + if jj == 0: + continue assert s[0, jj] == 0 assert s1[0, jj] == i assert s[1, jj] == 0 @@ -1045,13 +1055,14 @@ def test_op_csr(setup): assert s[0, 0] == i # ** - s = s1 ** csr + s = s1**csr for jj in j: - if jj == 0: continue + if jj == 0: + continue assert s[0, jj] == 1 assert s1[0, jj] == i assert s[1, jj] == 0 - assert s[0, 0] == i ** 2 + assert s[0, 0] == i**2 s1.empty() @@ -1059,39 +1070,39 @@ def test_op3(): S = SparseCSR((10, 100), dtype=np.int32) # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) S[0, j] = i - for op in ['add', 'sub', 'mul', 'pow']: - func = getattr(S, f'__{op}__') + for op in ["add", "sub", "mul", "pow"]: + func = getattr(S, f"__{op}__") s = func(1) assert s.dtype == np.int32 - s = func(1.) + s = func(1.0) assert s.dtype == np.float64 - if op != 'pow': - s = func(1.j) + if op != "pow": + s = func(1.0j) assert s.dtype == np.complex128 S = S.copy(dtype=np.float64) - for op in ['add', 'sub', 'mul', 'pow']: - func = getattr(S, f'__{op}__') + for op in ["add", "sub", "mul", "pow"]: + func = getattr(S, f"__{op}__") s = func(1) assert s.dtype == np.float64 - s = func(1.) + s = func(1.0) assert s.dtype == np.float64 - if op != 'pow': - s = func(1.j) + if op != "pow": + s = func(1.0j) assert s.dtype == np.complex128 S = S.copy(dtype=np.complex128) - for op in ['add', 'sub', 'mul', 'pow']: - func = getattr(S, f'__{op}__') + for op in ["add", "sub", "mul", "pow"]: + func = getattr(S, f"__{op}__") s = func(1) assert s.dtype == np.complex128 - s = func(1.) + s = func(1.0) assert s.dtype == np.complex128 - if op != 'pow': - s = func(1.j) + if op != "pow": + s = func(1.0j) assert s.dtype == np.complex128 @@ -1099,35 +1110,35 @@ def test_op4(): S = SparseCSR((10, 100), dtype=np.int32) # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) S[0, j] = i s = 1 + S assert s.dtype == np.int32 - s = 1. + S + s = 1.0 + S assert s.dtype == np.float64 - s = 1.j + S + s = 1.0j + S assert s.dtype == np.complex128 s = 1 - S assert s.dtype == np.int32 - s = 1. - S + s = 1.0 - S assert s.dtype == np.float64 - s = 1.j - S + s = 1.0j - S assert s.dtype == np.complex128 s = 1 * S assert s.dtype == np.int32 - s = 1. * S + s = 1.0 * S assert s.dtype == np.float64 - s = 1.j * S + s = 1.0j * S assert s.dtype == np.complex128 - s = 1 ** S + s = 1**S assert s.dtype == np.int32 - s = 1. ** S + s = 1.0**S assert s.dtype == np.float64 - s = 1.j ** S + s = 1.0j**S assert s.dtype == np.complex128 @@ -1137,7 +1148,7 @@ def test_op5(): S3 = SparseCSR((10, 100), dtype=np.int32) # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) S1[0, j] = i S2[0, j] = i @@ -1166,7 +1177,7 @@ def test_op5(): S //= 2 assert np.allclose(S.todense(), S1.todense()) - S = S1 / 2. + S = S1 / 2.0 S *= 2 assert np.allclose(S.todense(), S1.todense()) @@ -1176,7 +1187,7 @@ def test_op_numpy_scalar(): I = np.ones(1, dtype=np.complex64)[0] # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) S[0, j] = i S.finalize() @@ -1212,12 +1223,12 @@ def test_op_numpy_scalar(): assert s.dtype == np.complex64 assert s._D.sum() == Ssum - s = S ** I + s = S**I assert isinstance(s, SparseCSR) assert s.dtype == np.complex64 assert s._D.sum() == Ssum - s = I ** S + s = I**S assert isinstance(s, SparseCSR) assert s.dtype == np.complex64 @@ -1232,7 +1243,7 @@ def test_op_sparse_dim(): I = np.ones(1, dtype=np.complex64)[0] # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) S[0, j] = i S.finalize() @@ -1258,7 +1269,7 @@ def test_sparse_transpose(): I = np.ones(1, dtype=np.complex64)[0] # Create initial stuff for i in range(10): - j = range(i*4, i*4+3) + j = range(i * 4, i * 4 + 3) S[0, j] = i S.finalize() @@ -1303,17 +1314,15 @@ def test_op_reduce(): def test_unfinalized_math(): S1 = SparseCSR((4, 4, 1)) S2 = SparseCSR((4, 4, 1)) - S1[0, 0] = 2. - S1[1, 2] = 3. - S2[2, 3] = 4. - S2[2, 2] = 3. + S1[0, 0] = 2.0 + S1[1, 2] = 3.0 + S2[2, 3] = 4.0 + S2[2, 2] = 3.0 S2[0, 0] = 4 for i in range(3): - assert np.allclose(S1.todense() + S2.todense(), - (S1 + S2).todense()) - assert np.allclose(S1.todense() * S2.todense(), - (S1 * S2).todense()) + assert np.allclose(S1.todense() + S2.todense(), (S1 + S2).todense()) + assert np.allclose(S1.todense() * S2.todense(), (S1 * S2).todense()) sin = np.sin(S1.todense()) + np.sin(S2.todense()) sins = (np.sin(S1) + np.sin(S2)).todense() assert np.allclose(sin, sins) @@ -1335,6 +1344,7 @@ def test_unfinalized_math(): def test_pickle(): import pickle as p + S = SparseCSR((10, 10, 2), dtype=np.int32) S[0, 0] = [1, 2] S[2, 0] = [1, 2] @@ -1357,6 +1367,7 @@ def test_sparse_column_out_of_bounds(j): with pytest.raises(IndexError): S[0, j] = 1 + def test_fromsp_csr(): csr1 = sc.sparse.random(10, 100, 0.01, random_state=24812) csr2 = sc.sparse.random(10, 100, 0.02, random_state=24813) @@ -1365,8 +1376,8 @@ def test_fromsp_csr(): csr_1 = csr.tocsr(0) csr_2 = csr.tocsr(1) - assert np.abs(csr1 - csr_1).sum() == 0. - assert np.abs(csr2 - csr_2).sum() == 0. + assert np.abs(csr1 - csr_1).sum() == 0.0 + assert np.abs(csr2 - csr_2).sum() == 0.0 def test_transform1(): @@ -1380,7 +1391,7 @@ def test_transform1(): assert tr.shape[:2] == csr.shape[:2] assert tr.shape[2] == len(matrix) - assert np.abs(tr.tocsr(0) - 0.3 * csr1 - 0.7 * csr2).sum() == 0. + assert np.abs(tr.tocsr(0) - 0.3 * csr1 - 0.7 * csr2).sum() == 0.0 def test_transform2(): @@ -1394,8 +1405,8 @@ def test_transform2(): assert tr.shape[:2] == csr.shape[:2] assert tr.shape[2] == len(matrix) - assert np.abs(tr.tocsr(0) - 0.3 * csr1).sum() == 0. - assert np.abs(tr.tocsr(1) - 0.7 * csr2).sum() == 0. + assert np.abs(tr.tocsr(0) - 0.3 * csr1).sum() == 0.0 + assert np.abs(tr.tocsr(1) - 0.7 * csr2).sum() == 0.0 def test_transform3(): @@ -1409,9 +1420,9 @@ def test_transform3(): assert tr.shape[:2] == csr.shape[:2] assert tr.shape[2] == len(matrix) - assert np.abs(tr.tocsr(0) - 0.3 * csr1).sum() == 0. - assert np.abs(tr.tocsr(1) - 0.7 * csr2).sum() == 0. - assert np.abs(tr.tocsr(2) - 0.1 * csr1 - 0.2 * csr2).sum() == 0. + assert np.abs(tr.tocsr(0) - 0.3 * csr1).sum() == 0.0 + assert np.abs(tr.tocsr(1) - 0.7 * csr2).sum() == 0.0 + assert np.abs(tr.tocsr(2) - 0.1 * csr1 - 0.2 * csr2).sum() == 0.0 def test_transform4(): @@ -1425,7 +1436,7 @@ def test_transform4(): assert tr.shape[:2] == csr.shape[:2] assert tr.shape[2] == len(matrix) - assert np.abs(tr.tocsr(0) - 0.3j * csr1 - 0.7j * csr2).sum() == 0. + assert np.abs(tr.tocsr(0) - 0.3j * csr1 - 0.7j * csr2).sum() == 0.0 def test_transform_fail(): @@ -1434,7 +1445,7 @@ def test_transform_fail(): csr = SparseCSR.fromsp(csr1, csr2) # complex 1x3 matrix - matrix = [[0.3j, 0.7j, 1.]] + matrix = [[0.3j, 0.7j, 1.0]] with pytest.raises(ValueError): csr.transform(matrix=matrix) @@ -1453,7 +1464,7 @@ def test_fromsp_csr_large(): indices = row.indices if len(indices) == 0: indices = np.arange(3) - csr2[9948, (indices + 1) % 10] = 1. + csr2[9948, (indices + 1) % 10] = 1.0 assert csr1.getnnz() != csr2.getnnz() t0 = time() @@ -1463,10 +1474,10 @@ def test_fromsp_csr_large(): csr_1 = csr.tocsr(0) csr_2 = csr.tocsr(1) - assert np.abs(csr1 - csr_1).sum() == 0. - assert np.abs(csr2 - csr_2).sum() == 0. + assert np.abs(csr1 - csr_1).sum() == 0.0 + assert np.abs(csr2 - csr_2).sum() == 0.0 - csr_ = SparseCSR(csr1.shape + (2, ), nnzpr=1) + csr_ = SparseCSR(csr1.shape + (2,), nnzpr=1) t0 = time() for ic, c in enumerate([csr1, csr2]): @@ -1474,17 +1485,17 @@ def test_fromsp_csr_large(): # Loop stuff for r in range(c.shape[0]): - idx = csr_._extend(r, c.indices[ptr[r]:ptr[r+1]]) - csr_._D[idx, ic] += c.data[ptr[r]:ptr[r+1]] + idx = csr_._extend(r, c.indices[ptr[r] : ptr[r + 1]]) + csr_._D[idx, ic] += c.data[ptr[r] : ptr[r + 1]] if print_time: print(f"timing: 2 x ptr[]:ptr[] {time() - t0}") dcsr = csr - csr_ - assert np.abs(dcsr.tocsr(0)).sum() == 0. - assert np.abs(dcsr.tocsr(1)).sum() == 0. + assert np.abs(dcsr.tocsr(0)).sum() == 0.0 + assert np.abs(dcsr.tocsr(1)).sum() == 0.0 - csr_ = SparseCSR(csr1.shape + (2, ), nnzpr=1) + csr_ = SparseCSR(csr1.shape + (2,), nnzpr=1) t0 = time() for ic, c in enumerate([csr1, csr2]): @@ -1492,7 +1503,7 @@ def test_fromsp_csr_large(): # Loop stuff for r in range(c.shape[0]): - sl = slice(ptr[r], ptr[r+1]) + sl = slice(ptr[r], ptr[r + 1]) idx = csr_._extend(r, c.indices[sl]) csr_._D[idx, ic] += c.data[sl] if print_time: diff --git a/src/sisl/tests/test_sparse_geometry.py b/src/sisl/tests/test_sparse_geometry.py index 321019e93f..7c0c0626aa 100644 --- a/src/sisl/tests/test_sparse_geometry.py +++ b/src/sisl/tests/test_sparse_geometry.py @@ -16,16 +16,16 @@ @pytest.fixture def setup(): - class t(): + class t: def __init__(self): - self.g = fcc(1., Atom(1, R=1.495)) * 2 + self.g = fcc(1.0, Atom(1, R=1.495)) * 2 self.s1 = SparseAtom(self.g) self.s2 = SparseAtom(self.g, 2) + return t() class TestSparseAtom: - def test_fail_align1(self, setup): s = SparseAtom(setup.g * 2) str(s) @@ -133,7 +133,7 @@ def test_untile1(self, setup): def test_untile_wrong_usage(self): # one should not untile - geometry = Geometry([0] * 3, Atom(1, R=1.001), Lattice(1, nsc=[1]* 3)) + geometry = Geometry([0] * 3, Atom(1, R=1.001), Lattice(1, nsc=[1] * 3)) geometry = geometry.tile(4, 0) s = SparseAtom(geometry) s.construct([[0.1, 1.01], [1, 2]]) @@ -149,7 +149,7 @@ def test_untile_wrong_usage(self): def test_untile_segment_single(self): # one should not untile - geometry = Geometry([0] * 3, Atom(1, R=1.001), Lattice(1, nsc=[1]* 3)) + geometry = Geometry([0] * 3, Atom(1, R=1.001), Lattice(1, nsc=[1] * 3)) geometry = geometry.tile(4, 0) s = SparseAtom(geometry) s.construct([[0.1, 1.01], [1, 2]]) @@ -162,7 +162,7 @@ def test_untile_segment_single(self): sx = s.untile(4, 0, segment=seg).tile(2, 0) ds = s4 - sx ds.finalize() - assert np.absolute(ds)._csr._D.sum() == pytest.approx(0.) + assert np.absolute(ds)._csr._D.sum() == pytest.approx(0.0) @pytest.mark.parametrize("axis", [0, 1, 2]) def test_untile_segment_three(self, axis): @@ -182,26 +182,26 @@ def test_untile_segment_three(self, axis): sx = s.untile(4, axis, segment=seg).tile(2, axis) ds = s4 - sx ds.finalize() - assert np.absolute(ds)._csr._D.sum() == pytest.approx(0.) + assert np.absolute(ds)._csr._D.sum() == pytest.approx(0.0) def test_unrepeat_setup(self, setup): s1 = SparseAtom(setup.g) s1.construct([[0.1, 1.5], [1, 2]]) - s2 = SparseAtom(setup.g * ((2, 2, 2), 'r')) + s2 = SparseAtom(setup.g * ((2, 2, 2), "r")) s2.construct([[0.1, 1.5], [1, 2]]) s2 = s2.unrepeat(2, 2).unrepeat(2, 1).unrepeat(2, 0) assert s1.spsame(s2) s1 = SparseAtom(setup.g) s1.construct([[0.1, 1.5], [1, 2]]) - s2 = SparseAtom(setup.g * ([2, 1, 1], 'r')) + s2 = SparseAtom(setup.g * ([2, 1, 1], "r")) s2.construct([[0.1, 1.5], [1, 2]]) s2 = s2.unrepeat(2, 0) assert s1.spsame(s2) s1 = SparseAtom(setup.g) s1.construct([[0.1, 1.5], [1, 2]]) - s2 = SparseAtom(setup.g * ([1, 2, 1], 'r')) + s2 = SparseAtom(setup.g * ([1, 2, 1], "r")) s2.construct([[0.1, 1.5], [1, 2]]) s2 = s2.unrepeat(2, 1) assert s1.spsame(s2) @@ -218,7 +218,7 @@ def test_rij_fail1(self, setup): s = SparseOrbital(setup.g.copy()) s.construct([[0.1, 1.5], [1, 2]]) with pytest.raises(ValueError): - s.rij(what='none') + s.rij(what="none") def test_rij_atom(self, setup): s = SparseAtom(setup.g.copy()) @@ -235,7 +235,7 @@ def test_rij_atom_orbital_compare(self, setup): orbital = so.rij() assert atom.spsame(orbital) atom = sa.rij() - orbital = so.rij('atom') + orbital = so.rij("atom") assert atom.spsame(orbital) # This only works because there is 1 orbital per atom orbital = so.rij() @@ -254,10 +254,10 @@ def test_sp_orb_remove(self, setup): assert so.geometry.na - 1 == so2.geometry.na def test_sp_orb_remove_atom(self): - so = SparseOrbital(Geometry([[0] *3, [1]* 3], [Atom[1], Atom[2]], 2)) + so = SparseOrbital(Geometry([[0] * 3, [1] * 3], [Atom[1], Atom[2]], 2)) so2 = so.remove(Atom[1]) assert so.geometry.na - 1 == so2.geometry.na - assert so.geometry.no -1 == so2.geometry.no + assert so.geometry.no - 1 == so2.geometry.no def test_remove1(self, setup): for i in range(len(setup.g)): @@ -329,7 +329,7 @@ def test_repeat1(self, setup): setup.s1.construct([[0.1, 1.5], [1, 2]]) s1 = setup.s1.repeat(2, 0).repeat(2, 1) setup.s1.empty() - s2 = SparseAtom(setup.g * ([2, 2, 1], 'r')) + s2 = SparseAtom(setup.g * ([2, 2, 1], "r")) s2.construct([[0.1, 1.5], [1, 2]]) assert s1.spsame(s2) s1.finalize() @@ -341,7 +341,7 @@ def test_repeat2(self, setup): setup.s1.finalize() s1 = setup.s1.repeat(2, 0).repeat(2, 1) setup.s1.empty() - s2 = SparseAtom(setup.g * ([2, 2, 1], 'r')) + s2 = SparseAtom(setup.g * ([2, 2, 1], "r")) s2.construct([[0.1, 1.5], [1, 2]]) assert s1.spsame(s2) s1.finalize() @@ -363,7 +363,7 @@ def test_supercell_poisition1(self, setup): assert s1.spsame(s2) def test_set_nsc1(self, setup): - g = fcc(1., Atom(1, R=3.5)) + g = fcc(1.0, Atom(1, R=3.5)) s = SparseAtom(g) s.construct([[0.1, 1.5, 3.5], [1, 2, 3]]) s.finalize() @@ -421,7 +421,7 @@ def test_op_numpy_iscalar(self, setup): I = np.float32(1) # Create initial stuff for i in range(10): - j = range(i, i*2) + j = range(i, i * 2) S[0, j] = i S.finalize() @@ -453,7 +453,7 @@ def test_op_numpy_scalar(self, setup): I = np.ones(1, dtype=np.complex128)[0] # Create initial stuff for i in range(10): - j = range(i, i*2) + j = range(i, i * 2) S[0, j] = i S.finalize() @@ -489,12 +489,12 @@ def test_op_numpy_scalar(self, setup): assert s.dtype == np.complex128 assert s._csr._D.sum() == Ssum - s = S ** I + s = S**I assert isinstance(s, SparseAtom) assert s.dtype == np.complex128 assert s._csr._D.sum() == Ssum - s = I ** S + s = I**S assert isinstance(s, SparseAtom) assert s.dtype == np.complex128 @@ -504,7 +504,7 @@ def test_numpy_reduction(self, setup): I = np.ones(1, dtype=np.complex128)[0] # Create initial stuff for i in range(2): - j = range(i, i+2) + j = range(i, i + 2) S[i, j] = 1 S.finalize() assert np.sum(S, axis=(0, 1)) == pytest.approx(1 * 2 * 2) @@ -519,7 +519,7 @@ def test_fromsp1(self, setup): assert np.allclose(s1.shape, [g.na, g.na_s, 1]) assert np.allclose(s1[0, [1, 2, 3]], np.ones([3], np.int32)) - assert np.allclose(s1[1, [1, 2, 4]], np.ones([3], np.int32)*2) + assert np.allclose(s1[1, [1, 2, 4]], np.ones([3], np.int32) * 2) # Different instantiating s2 = SparseAtom.fromsp(g, lil) @@ -538,7 +538,7 @@ def test_fromsp2(self, setup): assert np.allclose(s1[0, [1, 2, 3], 0], np.ones([3], np.int32)) assert np.allclose(s1[0, [1, 2, 3], 1], np.zeros([3], np.int32)) assert np.allclose(s1[1, [1, 2, 4], 0], np.zeros([3], np.int32)) - assert np.allclose(s1[1, [1, 2, 4], 1], np.ones([3], np.int32)*2) + assert np.allclose(s1[1, [1, 2, 4], 1], np.ones([3], np.int32) * 2) def test_fromsp4(self, setup): g = setup.g.repeat(2, 0).tile(2, 1) @@ -569,7 +569,7 @@ def test_pickle(self, setup): @pytest.mark.parametrize("n1", [1, 3]) @pytest.mark.parametrize("n2", [1, 4]) def test_sparse_atom_symmetric(n0, n1, n2): - g = fcc(1., Atom(1, R=1.5)) * 2 + g = fcc(1.0, Atom(1, R=1.5)) * 2 s = SparseAtom(g) s.construct([[0.1, 1.51], [1, 2]]) s = s.tile(n0, 0).tile(n1, 1).tile(n2, 2) @@ -580,7 +580,7 @@ def test_sparse_atom_symmetric(n0, n1, n2): # orbitals connecting to ia edges = s.edges(ia) # Figure out the transposed supercell indices of the edges - isc = - s.geometry.a2isc(edges) + isc = -s.geometry.a2isc(edges) # Convert to supercell IA = s.geometry.lattice.sc_index(isc) * na + ia # Figure out if 'ia' is also in the back-edges @@ -594,64 +594,64 @@ def test_sparse_atom_symmetric(n0, n1, n2): @pytest.mark.parametrize("i", [0, 1, 2]) def test_sparse_atom_transpose_single(i): - """ This is problematic when the sparsity pattern is not *filled* """ - g = fcc(1., Atom(1, R=1.5)) * 3 + """This is problematic when the sparsity pattern is not *filled*""" + g = fcc(1.0, Atom(1, R=1.5)) * 3 s = SparseAtom(g) - s[i, 2] = 1. - s[i, 0] = 2. + s[i, 2] = 1.0 + s[i, 0] = 2.0 t = s.transpose() assert t.nnz == s.nnz - assert t[2, i] == pytest.approx(1.) - assert t[0, i] == pytest.approx(2.) + assert t[2, i] == pytest.approx(1.0) + assert t[0, i] == pytest.approx(2.0) @pytest.mark.parametrize("i", [0, 1, 2]) def test_sparse_atom_transpose_more(i): - """ This is problematic when the sparsity pattern is not *filled* """ - g = fcc(1., Atom(1, R=1.5)) * 3 + """This is problematic when the sparsity pattern is not *filled*""" + g = fcc(1.0, Atom(1, R=1.5)) * 3 s = SparseAtom(g) - s[i, 2] = 1. - s[i, 0] = 2. - s[i + 2, 3] = 1. - s[i + 2, 5] = 2. + s[i, 2] = 1.0 + s[i, 0] = 2.0 + s[i + 2, 3] = 1.0 + s[i + 2, 5] = 2.0 t = s.transpose() assert t.nnz == s.nnz - assert t[2, i] == pytest.approx(1.) - assert t[0, i] == pytest.approx(2.) - assert t[3, i + 2] == pytest.approx(1.) - assert t[5, i + 2] == pytest.approx(2.) + assert t[2, i] == pytest.approx(1.0) + assert t[0, i] == pytest.approx(2.0) + assert t[3, i + 2] == pytest.approx(1.0) + assert t[5, i + 2] == pytest.approx(2.0) @pytest.mark.parametrize("i", [0, 1, 2]) def test_sparse_orbital_transpose_single(i): - g = fcc(1., Atom(1, R=(1.5, 2.1))) * 3 + g = fcc(1.0, Atom(1, R=(1.5, 2.1))) * 3 s = SparseOrbital(g) - s[i, 2] = 1. - s[i, 0] = 2. + s[i, 2] = 1.0 + s[i, 0] = 2.0 t = s.transpose() assert t.nnz == s.nnz - assert t[2, i] == pytest.approx(1.) - assert t[0, i] == pytest.approx(2.) + assert t[2, i] == pytest.approx(1.0) + assert t[0, i] == pytest.approx(2.0) @pytest.mark.parametrize("i", [0, 1, 2]) def test_sparse_orbital_transpose_more(i): - g = fcc(1., Atom(1, R=(1.5, 2.1))) * 3 + g = fcc(1.0, Atom(1, R=(1.5, 2.1))) * 3 s = SparseOrbital(g) - s[i, 2] = 1. - s[i, 0] = 2. - s[i + 3, 4] = 1. - s[i + 2, 4] = 2. + s[i, 2] = 1.0 + s[i, 0] = 2.0 + s[i + 3, 4] = 1.0 + s[i + 2, 4] = 2.0 t = s.transpose() assert t.nnz == s.nnz - assert t[2, i] == pytest.approx(1.) - assert t[0, i] == pytest.approx(2.) - assert t[4, i + 3] == pytest.approx(1.) - assert t[4, i + 2] == pytest.approx(2.) + assert t[2, i] == pytest.approx(1.0) + assert t[0, i] == pytest.approx(2.0) + assert t[4, i + 3] == pytest.approx(1.0) + assert t[4, i + 2] == pytest.approx(2.0) def test_sparse_orbital_add_axis(setup): @@ -666,7 +666,8 @@ def test_sparse_orbital_add_axis(setup): def test_sparse_orbital_add_no_axis(): from sisl.geom import sc - g = (sc(1., Atom(1, R=1.5)) * 2).add(Lattice([0, 0, 5])) + + g = (sc(1.0, Atom(1, R=1.5)) * 2).add(Lattice([0, 0, 5])) s = SparseOrbital(g) s.construct([[0.1, 1.5], [1, 2]]) s1 = s.add(s, offset=[0, 0, 3]) @@ -677,7 +678,7 @@ def test_sparse_orbital_add_no_axis(): def test_sparse_orbital_sub_orbital(): atom = Atom(1, (1, 2, 3)) - g = fcc(1., atom) * 2 + g = fcc(1.0, atom) * 2 s = SparseOrbital(g) # take out some orbitals @@ -703,12 +704,12 @@ def test_translate_sparse_atoms(): # Build a dummy matrix with onsite terms and just one coupling term matrix = SparseAtom(graph) - matrix[0,0] = 1 - matrix[1,1] = 2 - matrix[0,1] = 3 + matrix[0, 0] = 1 + matrix[1, 1] = 2 + matrix[0, 1] = 3 # Translate the second atom - transl = matrix._translate_atoms_sc([[0,0,0], [1, 0, 0]]) + transl = matrix._translate_atoms_sc([[0, 0, 0], [1, 0, 0]]) # Check that the new auxiliary cell is correct. assert np.allclose(transl.nsc, [3, 1, 1]) @@ -719,10 +720,10 @@ def test_translate_sparse_atoms(): assert np.allclose(transl.geometry[1], matrix.geometry[1] + matrix.geometry.cell[0]) # Assert that the matrix elements have been translated - assert transl[0,0] == 1 - assert transl[1,1] == 2 - assert transl[0,1] == 0 - assert transl[0,3] == 3 + assert transl[0, 0] == 1 + assert transl[1, 1] == 2 + assert transl[0, 1] == 0 + assert transl[0, 3] == 3 # Translate back to unit cell uc_matrix = transl.translate2uc() @@ -733,13 +734,13 @@ def test_translate_sparse_atoms(): assert np.allclose(uc_matrix.geometry.xyz, matrix.geometry.xyz) - assert uc_matrix[0,0] == 1 - assert uc_matrix[1,1] == 2 - assert uc_matrix[0,1] == 3 - assert uc_matrix[0,3] == 0 + assert uc_matrix[0, 0] == 1 + assert uc_matrix[1, 1] == 2 + assert uc_matrix[0, 1] == 3 + assert uc_matrix[0, 3] == 0 # Instead, test atoms and axes arguments to avoid any translation. - for kwargs in [{'atoms': [0]}, {'axes': [1, 2]}]: + for kwargs in [{"atoms": [0]}, {"axes": [1, 2]}]: not_uc_matrix = transl.translate2uc(**kwargs) # Check the auxiliary cell, coordinates and the matrix elements @@ -751,16 +752,20 @@ def test_translate_sparse_atoms(): assert np.allclose(not_uc_matrix._csr.todense(), transl._csr.todense()) # Now, translate both atoms - transl_both = uc_matrix._translate_atoms_sc([[-1,0,0], [1, 0, 0]]) + transl_both = uc_matrix._translate_atoms_sc([[-1, 0, 0], [1, 0, 0]]) # Check the auxiliary cell, coordinates and the matrix elements assert np.allclose(transl_both.nsc, [5, 1, 1]) assert np.allclose(transl_both.shape, [2, 10, 1]) - assert np.allclose(transl_both.geometry[0], uc_matrix.geometry[0] - uc_matrix.geometry.cell[0]) - assert np.allclose(transl_both.geometry[1], uc_matrix.geometry[1] + uc_matrix.geometry.cell[0]) - - assert transl_both[0,0] == 1 - assert transl_both[1,1] == 2 - assert transl_both[0,1] == 0 - assert transl_both[0,3] == 3 + assert np.allclose( + transl_both.geometry[0], uc_matrix.geometry[0] - uc_matrix.geometry.cell[0] + ) + assert np.allclose( + transl_both.geometry[1], uc_matrix.geometry[1] + uc_matrix.geometry.cell[0] + ) + + assert transl_both[0, 0] == 1 + assert transl_both[1, 1] == 2 + assert transl_both[0, 1] == 0 + assert transl_both[0, 3] == 3 diff --git a/src/sisl/tests/test_sparse_orbital.py b/src/sisl/tests/test_sparse_orbital.py index 32b249c3c0..7592df9c2c 100644 --- a/src/sisl/tests/test_sparse_orbital.py +++ b/src/sisl/tests/test_sparse_orbital.py @@ -12,16 +12,18 @@ from sisl.geom import fcc, graphene from sisl.sparse_geometry import * -pytestmark = [pytest.mark.sparse, - pytest.mark.sparse_geometry, - pytest.mark.sparse_orbital] +pytestmark = [ + pytest.mark.sparse, + pytest.mark.sparse_geometry, + pytest.mark.sparse_orbital, +] @pytest.mark.parametrize("n0", [1, 3]) @pytest.mark.parametrize("n1", [1, 4]) @pytest.mark.parametrize("n2", [1, 2]) def test_sparse_orbital_symmetric(n0, n1, n2): - g = fcc(1., Atom(1, R=1.5)) * 2 + g = fcc(1.0, Atom(1, R=1.5)) * 2 s = SparseOrbital(g) s.construct([[0.1, 1.51], [1, 2]]) s = s.tile(n0, 0).tile(n1, 1).tile(n2, 2) @@ -32,7 +34,7 @@ def test_sparse_orbital_symmetric(n0, n1, n2): # orbitals connecting to io edges = s.edges(io) # Figure out the transposed supercell indices of the edges - isc = - s.geometry.o2isc(edges) + isc = -s.geometry.o2isc(edges) # Convert to supercell IO = s.geometry.lattice.sc_index(isc) * no + io # Figure out if 'io' is also in the back-edges @@ -49,7 +51,7 @@ def test_sparse_orbital_symmetric(n0, n1, n2): @pytest.mark.parametrize("n2", [1, 2]) @pytest.mark.parametrize("axis", [0, 1]) def test_sparse_orbital_append(n0, n1, n2, axis): - g = fcc(1., Atom(1, R=1.98)) * 2 + g = fcc(1.0, Atom(1, R=1.98)) * 2 dists = np.insert(g.distance(0, R=g.maxR()) + 0.001, 0, 0.001) connect = np.arange(dists.size, dtype=np.float64) / 5 s = SparseOrbital(g) @@ -88,7 +90,7 @@ def test_sparse_orbital_append(n0, n1, n2, axis): @pytest.mark.parametrize("n2", [1, 2]) @pytest.mark.parametrize("axis", [1, 2]) def test_sparse_orbital_append_scale(n0, n1, n2, axis): - g = fcc(1., Atom(1, R=1.98)) * 2 + g = fcc(1.0, Atom(1, R=1.98)) * 2 dists = np.insert(g.distance(0, R=g.maxR()) + 0.001, 0, 0.001) connect = np.arange(dists.size, dtype=np.float64) / 5 s = SparseOrbital(g) @@ -116,13 +118,13 @@ def test_sparse_orbital_append_scale(n0, n1, n2, axis): s = sf.sub(np.concatenate([idx1, s1.na + idx2])) s.finalize() - sout = s1.sub(idx1).append(s2.sub(idx2), axis, scale=(2., 0)) + sout = s1.sub(idx1).append(s2.sub(idx2), axis, scale=(2.0, 0)) sout = (sout + sout.transpose()) * 0.5 assert sout.spsame(s) sout.finalize() assert np.allclose(s._csr._D, sout._csr._D) - sout = s1.sub(idx1).append(s2.sub(idx2), axis, scale=(0., 2.)) + sout = s1.sub(idx1).append(s2.sub(idx2), axis, scale=(0.0, 2.0)) sout.finalize() # Ensure that some elements are not the same! assert not np.allclose(s._csr._D, sout._csr._D) @@ -137,10 +139,10 @@ def test_sparse_orbital_hermitian(): for f in [True, False]: spo = SparseOrbital(g) - spo[0, 0] = 1. + spo[0, 0] = 1.0 # Create only a single coupling to the neighouring element - spo[0, 1] = 2. + spo[0, 1] = 2.0 if f: spo.finalize() @@ -151,16 +153,16 @@ def test_sparse_orbital_hermitian(): spoT = spo.transpose() assert spoT.finalized assert spoT.nnz == 2 - assert spoT[0, 0] == 1. - assert spoT[0, 1] == 0. - assert spoT[0, 2] == 2. + assert spoT[0, 0] == 1.0 + assert spoT[0, 1] == 0.0 + assert spoT[0, 2] == 2.0 spoH = (spo + spoT) * 0.5 assert spoH.nnz == 3 - assert spoH[0, 0] == 1. - assert spoH[0, 1] == 1. - assert spoH[0, 2] == 1. + assert spoH[0, 0] == 1.0 + assert spoH[0, 1] == 1.0 + assert spoH[0, 2] == 1.0 def test_sparse_orbital_sub_orbital(): @@ -179,12 +181,7 @@ def test_sparse_orbital_sub_orbital(): spo = spo + spo.transpose() # orbitals on first atom (a0) - rem_sub = [ - (0, [1, 2]), - ([0, 2], 1), - (2, [0, 1]), - (a0[0], [1, 2]) - ] + rem_sub = [(0, [1, 2]), ([0, 2], 1), (2, [0, 1]), (a0[0], [1, 2])] for rem, sub in rem_sub: spo_rem = spo.remove_orbital(0, rem) spo_sub = spo.sub_orbital(0, sub) @@ -245,7 +242,7 @@ def test_sparse_orbital_sub_orbital_nested(): Doing nested or multiple subs that exposes the same sub-atom should ultimately re-use the existing atom. - However, due to the renaming of the tags + However, due to the renaming of the tags """ a0 = Atom(1, R=(1.1, 1.4, 1.6)) a1 = Atom(2, R=(1.3, 1.1)) @@ -291,29 +288,28 @@ def test_sparse_orbital_replace_simple(): # now replace every position that can be replaced for position in range(0, spo44.geometry.na, 2): - # replace both atoms - new = spo44.replace([position, position+1], spo) - assert np.fabs((new - spo44)._csr._D).sum() == 0. + new = spo44.replace([position, position + 1], spo) + assert np.fabs((new - spo44)._csr._D).sum() == 0.0 # replace both atoms (reversed) # When swapping it is not the same - new = spo44.replace([position, position+1], spo, [1, 0]) - assert np.fabs((new - spo44)._csr._D).sum() != 0. + new = spo44.replace([position, position + 1], spo, [1, 0]) + assert np.fabs((new - spo44)._csr._D).sum() != 0.0 pvt = np.arange(spo44.na) pvt[position] = position + 1 - pvt[position+1] = position + pvt[position + 1] = position assert np.unique(pvt).size == pvt.size new_pvt = spo44.sub(pvt) - assert np.fabs((new - new_pvt)._csr._D).sum() == 0. + assert np.fabs((new - new_pvt)._csr._D).sum() == 0.0 # replace first atom new = spo44.replace([position], spo, 0) - assert np.fabs((new - spo44)._csr._D).sum() == 0. + assert np.fabs((new - spo44)._csr._D).sum() == 0.0 # replace second atom - new = spo44.replace([position+1], spo, 1) - assert np.fabs((new - spo44)._csr._D).sum() == 0. + new = spo44.replace([position + 1], spo, 1) + assert np.fabs((new - spo44)._csr._D).sum() == 0.0 def test_sparse_orbital_replace_specie_warn(): @@ -338,7 +334,7 @@ def test_sparse_orbital_replace_specie_warn(): def test_sparse_orbital_replace_hole(): - """ Create a big graphene flake remove a hole (1 orbital system) """ + """Create a big graphene flake remove a hole (1 orbital system)""" g = graphene(orthogonal=True) spo = SparseOrbital(g) # create the sparse-orbital @@ -364,25 +360,29 @@ def create_sp(geom): assert len(atoms) == 4 * 6 * 6 new = big.replace(atoms, hole) new_copy = create_sp(new.geometry) - assert np.fabs((new - new_copy)._csr._D).sum() == 0. + assert np.fabs((new - new_copy)._csr._D).sum() == 0.0 def test_sparse_orbital_replace_hole_norbs(): - """ Create a big graphene flake remove a hole (multiple orbitals) """ + """Create a big graphene flake remove a hole (multiple orbitals)""" a1 = Atom(5, R=(1.44, 1.44)) a2 = Atom(7, R=(1.44, 1.44, 1.44)) g = graphene(atoms=[a1, a2], orthogonal=True) spo = SparseOrbital(g) + def func(self, ia, atoms, atoms_xyz=None): geom = self.geometry + def a2o(idx): return geom.a2o(idx, True) + io = a2o(ia) idx = self.geometry.close(ia, R=[0.1, 1.44], atoms=atoms, atoms_xyz=atoms_xyz) idx = list(map(a2o, idx)) self[io, idx[0]] = 0 for i in io: self[i, idx[1]] = 2.7 + # create the sparse-orbital spo.construct(func) @@ -406,7 +406,7 @@ def create_sp(geom): assert len(atoms) == 4 * 6 * 6 new = big.replace(atoms, hole) new_copy = create_sp(new.geometry) - assert np.fabs((new - new_copy)._csr._D).sum() == 0. + assert np.fabs((new - new_copy)._csr._D).sum() == 0.0 def test_translate_sparse_orbitals(): @@ -417,13 +417,13 @@ def test_translate_sparse_orbitals(): # Build a dummy matrix with onsite terms and just one coupling term matrix = SparseOrbital(graph) - matrix[0,0] = 1 - matrix[1,1] = 2 - matrix[0,1] = 3 - matrix[0,4] = 4 + matrix[0, 0] = 1 + matrix[1, 1] = 2 + matrix[0, 1] = 3 + matrix[0, 4] = 4 # Translate the second atom - transl = matrix._translate_atoms_sc([[0,0,0], [1, 0, 0]]) + transl = matrix._translate_atoms_sc([[0, 0, 0], [1, 0, 0]]) # Check that the new auxiliary cell is correct. assert np.allclose(transl.nsc, [3, 1, 1]) @@ -434,10 +434,10 @@ def test_translate_sparse_orbitals(): assert np.allclose(transl.geometry[1], matrix.geometry[1] + matrix.geometry.cell[0]) # Assert that the matrix elements have been translated - assert transl[0,0] == 1 - assert transl[1,1] == 2 - assert transl[0,1] == 3 - assert transl[0,6 + 4] == 4 + assert transl[0, 0] == 1 + assert transl[1, 1] == 2 + assert transl[0, 1] == 3 + assert transl[0, 6 + 4] == 4 # Translate back to unit cell uc_matrix = transl.translate2uc() @@ -448,13 +448,13 @@ def test_translate_sparse_orbitals(): assert np.allclose(uc_matrix.geometry.xyz, matrix.geometry.xyz) - assert uc_matrix[0,0] == 1 - assert uc_matrix[1,1] == 2 - assert uc_matrix[0,1] == 3 - assert uc_matrix[0,4] == 4 + assert uc_matrix[0, 0] == 1 + assert uc_matrix[1, 1] == 2 + assert uc_matrix[0, 1] == 3 + assert uc_matrix[0, 4] == 4 # Instead, test atoms and axes arguments to avoid any translation. - for kwargs in [{'atoms': [0]}, {'axes': [1, 2]}]: + for kwargs in [{"atoms": [0]}, {"axes": [1, 2]}]: not_uc_matrix = transl.translate2uc(**kwargs) # Check the auxiliary cell, coordinates and the matrix elements @@ -466,16 +466,20 @@ def test_translate_sparse_orbitals(): assert np.allclose(not_uc_matrix._csr.todense(), transl._csr.todense()) # Now, translate both atoms - transl_both = uc_matrix._translate_atoms_sc([[-1,0,0], [1, 0, 0]]) + transl_both = uc_matrix._translate_atoms_sc([[-1, 0, 0], [1, 0, 0]]) # Check the auxiliary cell, coordinates and the matrix elements assert np.allclose(transl_both.nsc, [5, 1, 1]) - assert np.allclose(transl_both.shape, [6, 6*5, 1]) - - assert np.allclose(transl_both.geometry[0], uc_matrix.geometry[0] - uc_matrix.geometry.cell[0]) - assert np.allclose(transl_both.geometry[1], uc_matrix.geometry[1] + uc_matrix.geometry.cell[0]) - - assert transl_both[0,0] == 1 - assert transl_both[1,1] == 2 - assert transl_both[0,1] == 3 - assert transl_both[0,6+4] == 4 + assert np.allclose(transl_both.shape, [6, 6 * 5, 1]) + + assert np.allclose( + transl_both.geometry[0], uc_matrix.geometry[0] - uc_matrix.geometry.cell[0] + ) + assert np.allclose( + transl_both.geometry[1], uc_matrix.geometry[1] + uc_matrix.geometry.cell[0] + ) + + assert transl_both[0, 0] == 1 + assert transl_both[1, 1] == 2 + assert transl_both[0, 1] == 3 + assert transl_both[0, 6 + 4] == 4 diff --git a/src/sisl/typing/_common.py b/src/sisl/typing/_common.py index 075e578894..3280f0ac5d 100644 --- a/src/sisl/typing/_common.py +++ b/src/sisl/typing/_common.py @@ -15,7 +15,9 @@ # An atoms like argument that may be parsed by Geometry._sanitize_atoms AtomsArgument = Union[ npt.NDArray[Union[np.int_, np.bool_]], - str, int, dict, + str, + int, + dict, Atom, AtomCategory, GenericCategory, @@ -24,7 +26,9 @@ OrbitalsArgument = Union[ npt.NDArray[Union[np.int_, np.bool_]], - str, int, dict, + str, + int, + dict, AtomCategory, Shape, ] diff --git a/src/sisl/typing/tests/test_typing.py b/src/sisl/typing/tests/test_typing.py index 135880273e..a5e6418af8 100644 --- a/src/sisl/typing/tests/test_typing.py +++ b/src/sisl/typing/tests/test_typing.py @@ -10,8 +10,7 @@ pytestmark = pytest.mark.typing -def test_argument(): +def test_argument(): def func(a: st.AtomsArgument): str(a) - diff --git a/src/sisl/unit/base.py b/src/sisl/unit/base.py index 1b0079b7eb..d0acfb7c64 100644 --- a/src/sisl/unit/base.py +++ b/src/sisl/unit/base.py @@ -15,49 +15,54 @@ # This is the CODATA-2018 units unit_table = { - 'mass': {'DEFAULT': 'amu', - 'kg': 1.0, - 'g': 0.001, - 'amu': 1.6605390666e-27}, - 'length': {'DEFAULT': 'Ang', - 'm': 1.0, - 'cm': 0.01, - 'nm': 1e-09, - 'Ang': 1e-10, - 'pm': 1e-12, - 'fm': 1e-15, - 'Bohr': 5.29177210903e-11}, - 'time': {'DEFAULT': 'fs', - 's': 1.0, - 'ns': 1e-09, - 'ps': 1e-12, - 'fs': 1e-15, - 'min': 60.0, - 'hour': 3600.0, - 'day': 86400.0, - 'atu': 2.4188843265857e-17}, - 'energy': {'DEFAULT': 'eV', - 'J': 1.0, - 'kJ': 1.e3, - 'erg': 1e-07, - 'K': 1.380649e-23, - 'eV': 1.602176634e-19, - 'meV': 1.6021766339999998e-22, - 'Ha': 4.3597447222071e-18, - 'mHa': 4.3597447222071e-21, - 'Ry': 2.1798723611035e-18, - 'mRy': 2.1798723611035e-21}, - 'force': {'DEFAULT': 'eV/Ang', - 'N': 1.0, - 'eV/Ang': 1.6021766339999998e-09, - 'Ry/Bohr': 4.1193617491269446e-08, - 'Ha/Bohr': 8.238723498254079e-08} + "mass": {"DEFAULT": "amu", "kg": 1.0, "g": 0.001, "amu": 1.6605390666e-27}, + "length": { + "DEFAULT": "Ang", + "m": 1.0, + "cm": 0.01, + "nm": 1e-09, + "Ang": 1e-10, + "pm": 1e-12, + "fm": 1e-15, + "Bohr": 5.29177210903e-11, + }, + "time": { + "DEFAULT": "fs", + "s": 1.0, + "ns": 1e-09, + "ps": 1e-12, + "fs": 1e-15, + "min": 60.0, + "hour": 3600.0, + "day": 86400.0, + "atu": 2.4188843265857e-17, + }, + "energy": { + "DEFAULT": "eV", + "J": 1.0, + "kJ": 1.0e3, + "erg": 1e-07, + "K": 1.380649e-23, + "eV": 1.602176634e-19, + "meV": 1.6021766339999998e-22, + "Ha": 4.3597447222071e-18, + "mHa": 4.3597447222071e-21, + "Ry": 2.1798723611035e-18, + "mRy": 2.1798723611035e-21, + }, + "force": { + "DEFAULT": "eV/Ang", + "N": 1.0, + "eV/Ang": 1.6021766339999998e-09, + "Ry/Bohr": 4.1193617491269446e-08, + "Ha/Bohr": 8.238723498254079e-08, + }, } @set_module("sisl.unit") def unit_group(unit, tbl=unit_table): - """ The group of units that `unit` belong to + """The group of units that `unit` belong to Parameters ---------- @@ -76,12 +81,12 @@ def unit_group(unit, tbl=unit_table): for k in tbl: if unit in tbl[k]: return k - raise ValueError(f"The unit ""{unit!s}"" could not be located in the table.") + raise ValueError(f"The unit " "{unit!s}" " could not be located in the table.") @set_module("sisl.unit") def unit_default(group, tbl=unit_table): - """ The default unit of the unit group `group`. + """The default unit of the unit group `group`. Parameters ---------- @@ -104,7 +109,7 @@ def unit_default(group, tbl=unit_table): @set_module("sisl.unit") def unit_convert(fr, to, opts=None, tbl=unit_table): - """ Factor that takes `fr` to the units of `to` + """Factor that takes `fr` to the units of `to` Parameters ---------- @@ -146,7 +151,9 @@ def unit_convert(fr, to, opts=None, tbl=unit_table): toU = k toV = tbl[k][to] if frU != toU: - raise ValueError(f"The unit conversion is not from the same group: {frU} to {toU}") + raise ValueError( + f"The unit conversion is not from the same group: {frU} to {toU}" + ) # Calculate conversion factor val = frV / toV @@ -169,13 +176,14 @@ def unit_convert(fr, to, opts=None, tbl=unit_table): @set_module("sisl.unit") class UnitParser: - """ Object for converting between units for a set of unit-tables. + """Object for converting between units for a set of unit-tables. Parameters ---------- unit_table : dict a table with the units parsable by the class """ + __slots__ = ("_table", "_p_left", "_left", "_p_right", "_right") def __init__(self, table): @@ -193,7 +201,9 @@ def group(unit): for k in tbl: if unit in tbl[k]: return k - raise ValueError(f"The unit ""{unit!s}"" could not be located in the table.") + raise ValueError( + f"The unit " "{unit!s}" " could not be located in the table." + ) def default(group): tbl = self._table @@ -209,10 +219,11 @@ def default(group): @staticmethod def create_parser(value, default, group, group_table=None): - """ Routine to internally create a parser with specified unit_convert, unit_default and unit_group routines """ + """Routine to internally create a parser with specified unit_convert, unit_default and unit_group routines""" # Any length of characters will be used as a word. if group_table is None: + def _value(t): return value(t[0]) @@ -220,6 +231,7 @@ def _float(t): return float(t[0]) else: + def _value(t): group_table.append(group(t[0])) return value(t[0]) @@ -240,18 +252,25 @@ def _float(t): exponent = pp.Combine(e + sign_integer) sign_integer = pp.Combine(pp.Optional(plusorminus) + integer) exponent = pp.Combine(e + sign_integer) - number = pp.Or([pp.Combine(point + integer + pp.Optional(exponent)), # .[0-9][E+-[0-9]] - pp.Combine(integer + pp.Optional(point + pp.Optional(integer)) + pp.Optional(exponent))] # [0-9].[0-9][E+-[0-9]] + number = pp.Or( + [ + pp.Combine(point + integer + pp.Optional(exponent)), # .[0-9][E+-[0-9]] + pp.Combine( + integer + + pp.Optional(point + pp.Optional(integer)) + + pp.Optional(exponent) + ), + ] # [0-9].[0-9][E+-[0-9]] ).setParseAction(_float) - #def _print_toks(name, op): + # def _print_toks(name, op): # """ May be used in pow_op.setParseAction(_print_toks("pow", "^")) to debug """ # def T(t): # print("{}: {}".format(name, t)) # return op # return T - #def _fix_toks(op): + # def _fix_toks(op): # """ May be used in pow_op.setParseAction(_print_toks("pow", "^")) to debug """ # def T(t): # return op @@ -265,6 +284,7 @@ def _float(t): base_op = pp.Empty() if group_table is None: + def pow_action(toks): return toks[0][0] ** toks[0][2] @@ -278,11 +298,12 @@ def base_action(toks): return toks[0][0] * toks[0][1] else: + def pow_action(toks): # Fix table of units group = "{}^{}".format(group_table[-2], group_table.pop()) group_table[-1] = group - #print("^", toks[0], group_table) + # print("^", toks[0], group_table) return toks[0][0] ** toks[0][2] def mul_action(toks): @@ -290,7 +311,7 @@ def mul_action(toks): group_table.pop(-2) if isinstance(group_table[-1], float): group_table.pop() - #print("*", toks[0], group_table) + # print("*", toks[0], group_table) return toks[0][0] * toks[0][2] def div_action(toks): @@ -300,7 +321,7 @@ def div_action(toks): group_table.pop() else: group_table[-1] = "/{}".format(group_table[-1]) - #print("/", toks[0]) + # print("/", toks[0]) return toks[0][0] / toks[0][2] def base_action(toks): @@ -311,17 +332,21 @@ def base_action(toks): return toks[0][0] * toks[0][1] # We should parse numbers first - parser = pp.infixNotation(number | unit, - [(pow_op, 2, pp.opAssoc.RIGHT, pow_action), - (mul_op, 2, pp.opAssoc.LEFT, mul_action), - (div_op, 2, pp.opAssoc.LEFT, div_action), - (base_op, 2, pp.opAssoc.LEFT, base_action)]) + parser = pp.infixNotation( + number | unit, + [ + (pow_op, 2, pp.opAssoc.RIGHT, pow_action), + (mul_op, 2, pp.opAssoc.LEFT, mul_action), + (div_op, 2, pp.opAssoc.LEFT, div_action), + (base_op, 2, pp.opAssoc.LEFT, base_action), + ], + ) return parser @staticmethod def same_group(A, B): - """ Return true if A and B have the same groups """ + """Return true if A and B have the same groups""" A.sort() B.sort() if len(A) != len(B): @@ -329,7 +354,7 @@ def same_group(A, B): return all(a == b for a, b in zip(A, B)) def _convert(self, A, B): - """ Internal routine used to convert unit `A` to unit `B` """ + """Internal routine used to convert unit `A` to unit `B`""" conv_A = self._p_left.parseString(A)[0] conv_B = self._p_right.parseString(B)[0] if not self.same_group(self._left, self._right): @@ -337,13 +362,15 @@ def _convert(self, A, B): right = list(self._right) self._left.clear() self._right.clear() - raise ValueError(f"The unit conversion is not from the same group: {left} to {right}!") + raise ValueError( + f"The unit conversion is not from the same group: {left} to {right}!" + ) self._left.clear() self._right.clear() return conv_A / conv_B def convert(self, *units): - """ Conversion factors between units + """Conversion factors between units If 1 unit is passed a conversion to the default will be returned. If 2 parameters are passed then a single float will be returned that converts from diff --git a/src/sisl/unit/siesta.py b/src/sisl/unit/siesta.py index 76dcc9ecc9..96513f3e4b 100644 --- a/src/sisl/unit/siesta.py +++ b/src/sisl/unit/siesta.py @@ -18,50 +18,65 @@ from .base import unit_group as u_group from .base import unit_table -__all__ = ["unit_group", "unit_convert", "unit_default", "units"] - - -register_environ_variable("SISL_UNIT_SIESTA", "codata2018", - "Choose default units used when parsing Siesta files. [codata2018, legacy]", - process=str.lower) - - -unit_table_siesta_codata2018 = dict({key: dict(values) for key, values in unit_table.items()}) -unit_table_siesta_legacy = dict({key: dict(values) for key, values in unit_table.items()}) - -unit_table_siesta_legacy["length"].update({ - "Bohr": 0.529177e-10, -}) - -unit_table_siesta_legacy["time"].update({ - "mins": 60., - "hours": 3600., - "days": 86400., -}) - -unit_table_siesta_legacy["energy"].update({ - "meV": 1.60219e-22, - "eV": 1.60219e-19, - "mRy": 2.17991e-21, - "Ry": 2.17991e-18, - "mHa": 4.35982e-21, - "Ha": 4.35982e-18, - "Hartree": 4.35982e-18, - "K": 1.38066e-23, - "kJ/mol": 1.6606e-21, - "Hz": 6.6262e-34, - "THz": 6.6262e-22, - "cm-1": 1.986e-23, - "cm**-1": 1.986e-23, - "cm^-1": 1.986e-23, -}) - -unit_table_siesta_legacy["force"].update({ - "eV/Ang": 1.60219e-9, - "eV/Bohr": 1.60219e-9*0.529177, - "Ry/Bohr": 4.11943e-8, - "Ry/Ang": 4.11943e-8/0.529177, -}) +__all__ = ["unit_group", "unit_convert", "unit_default", "units"] + + +register_environ_variable( + "SISL_UNIT_SIESTA", + "codata2018", + "Choose default units used when parsing Siesta files. [codata2018, legacy]", + process=str.lower, +) + + +unit_table_siesta_codata2018 = dict( + {key: dict(values) for key, values in unit_table.items()} +) +unit_table_siesta_legacy = dict( + {key: dict(values) for key, values in unit_table.items()} +) + +unit_table_siesta_legacy["length"].update( + { + "Bohr": 0.529177e-10, + } +) + +unit_table_siesta_legacy["time"].update( + { + "mins": 60.0, + "hours": 3600.0, + "days": 86400.0, + } +) + +unit_table_siesta_legacy["energy"].update( + { + "meV": 1.60219e-22, + "eV": 1.60219e-19, + "mRy": 2.17991e-21, + "Ry": 2.17991e-18, + "mHa": 4.35982e-21, + "Ha": 4.35982e-18, + "Hartree": 4.35982e-18, + "K": 1.38066e-23, + "kJ/mol": 1.6606e-21, + "Hz": 6.6262e-34, + "THz": 6.6262e-22, + "cm-1": 1.986e-23, + "cm**-1": 1.986e-23, + "cm^-1": 1.986e-23, + } +) + +unit_table_siesta_legacy["force"].update( + { + "eV/Ang": 1.60219e-9, + "eV/Bohr": 1.60219e-9 * 0.529177, + "Ry/Bohr": 4.11943e-8, + "Ry/Ang": 4.11943e-8 / 0.529177, + } +) # Check for the correct handlers @@ -71,7 +86,9 @@ elif _def_unit in ("legacy", "original"): unit_table_siesta = unit_table_siesta_legacy else: - raise ValueError(f"Could not understand SISL_UNIT_SIESTA={_def_unit}, expected one of [codata2018, legacy]") + raise ValueError( + f"Could not understand SISL_UNIT_SIESTA={_def_unit}, expected one of [codata2018, legacy]" + ) @set_module("sisl.unit.siesta") diff --git a/src/sisl/unit/tests/test_unit.py b/src/sisl/unit/tests/test_unit.py index 757c3b2346..72749b0289 100644 --- a/src/sisl/unit/tests/test_unit.py +++ b/src/sisl/unit/tests/test_unit.py @@ -13,60 +13,62 @@ def test_group(): - assert unit_group('kg') == 'mass' - assert unit_group('eV') == 'energy' - assert unit_group('N') == 'force' + assert unit_group("kg") == "mass" + assert unit_group("eV") == "energy" + assert unit_group("N") == "force" def test_unit_convert(): - assert approx(unit_convert('kg', 'g')) == 1.e3 - assert approx(unit_convert('eV', 'J')) == 1.60217733e-19 - assert approx(unit_convert('J', 'eV')) == 1/1.60217733e-19 - assert approx(unit_convert('J', 'eV', opts={'^': 2})) == (1/1.60217733e-19) ** 2 - assert approx(unit_convert('J', 'eV', opts={'/': 2})) == (1/1.60217733e-19) / 2 - assert approx(unit_convert('J', 'eV', opts={'*': 2})) == (1/1.60217733e-19) * 2 + assert approx(unit_convert("kg", "g")) == 1.0e3 + assert approx(unit_convert("eV", "J")) == 1.60217733e-19 + assert approx(unit_convert("J", "eV")) == 1 / 1.60217733e-19 + assert approx(unit_convert("J", "eV", opts={"^": 2})) == (1 / 1.60217733e-19) ** 2 + assert approx(unit_convert("J", "eV", opts={"/": 2})) == (1 / 1.60217733e-19) / 2 + assert approx(unit_convert("J", "eV", opts={"*": 2})) == (1 / 1.60217733e-19) * 2 def test_class_unit(): - assert np.allclose(units.convert('J', 'J', 'J'), 1) - - assert approx(units.convert('kg', 'g')) == 1.e3 - assert approx(units.convert('eV', 'J')) == 1.60217733e-19 - assert approx(units.convert('J', 'eV')) == 1/1.60217733e-19 - assert approx(units.convert('J^2', 'eV**2')) == (1/1.60217733e-19) ** 2 - assert approx(units.convert('J/2', 'eV/2')) == (1/1.60217733e-19) - assert approx(units.convert('J', 'eV/2')) == (1/1.60217733e-19) * 2 - assert approx(units.convert('J2', '2eV')) == (1/1.60217733e-19) - assert approx(units.convert('J2', 'eV')) == (1/1.60217733e-19) * 2 - assert approx(units.convert('J/m', 'eV/Ang')) == unit_convert('J', 'eV') / unit_convert('m', 'Ang') - units('J**eV', 'eV**eV') - units('J/m', 'eV/m') + assert np.allclose(units.convert("J", "J", "J"), 1) + + assert approx(units.convert("kg", "g")) == 1.0e3 + assert approx(units.convert("eV", "J")) == 1.60217733e-19 + assert approx(units.convert("J", "eV")) == 1 / 1.60217733e-19 + assert approx(units.convert("J^2", "eV**2")) == (1 / 1.60217733e-19) ** 2 + assert approx(units.convert("J/2", "eV/2")) == (1 / 1.60217733e-19) + assert approx(units.convert("J", "eV/2")) == (1 / 1.60217733e-19) * 2 + assert approx(units.convert("J2", "2eV")) == (1 / 1.60217733e-19) + assert approx(units.convert("J2", "eV")) == (1 / 1.60217733e-19) * 2 + assert approx(units.convert("J/m", "eV/Ang")) == unit_convert( + "J", "eV" + ) / unit_convert("m", "Ang") + units("J**eV", "eV**eV") + units("J/m", "eV/m") def test_default(): - assert unit_default('mass') == 'amu' - assert unit_default('energy') == 'eV' - assert unit_default('force') == 'eV/Ang' + assert unit_default("mass") == "amu" + assert unit_default("energy") == "eV" + assert unit_default("force") == "eV/Ang" def test_group_f1(): with pytest.raises(ValueError): - unit_group('not-existing') + unit_group("not-existing") def test_default_f1(): with pytest.raises(ValueError): - unit_default('not-existing') + unit_default("not-existing") def test_unit_convert_f1(): with pytest.raises(ValueError): - unit_convert('eV', 'megaerg') + unit_convert("eV", "megaerg") def test_unit_convert_f2(): with pytest.raises(ValueError): - unit_convert('eV', 'kg') + unit_convert("eV", "kg") def test_unit_convert_single(): diff --git a/src/sisl/unit/tests/test_unit_siesta.py b/src/sisl/unit/tests/test_unit_siesta.py index 0ca3d26d54..b1010055d3 100644 --- a/src/sisl/unit/tests/test_unit_siesta.py +++ b/src/sisl/unit/tests/test_unit_siesta.py @@ -11,41 +11,41 @@ def test_group(): - assert unit_group('kg') == 'mass' - assert unit_group('eV') == 'energy' - assert unit_group('N') == 'force' + assert unit_group("kg") == "mass" + assert unit_group("eV") == "energy" + assert unit_group("N") == "force" def test_unit_convert(): - assert approx(unit_convert('kg', 'g')) == 1.e3 - assert approx(unit_convert('eV', 'J')) == 1.602176634e-19 - assert approx(unit_convert('J', 'eV')) == 1/1.602176634e-19 - assert approx(unit_convert('J', 'eV', {'^': 2})) == (1/1.602176634e-19) ** 2 - assert approx(unit_convert('J', 'eV', {'/': 2})) == (1/1.602176634e-19) / 2 - assert approx(unit_convert('J', 'eV', {'*': 2})) == (1/1.602176634e-19) * 2 + assert approx(unit_convert("kg", "g")) == 1.0e3 + assert approx(unit_convert("eV", "J")) == 1.602176634e-19 + assert approx(unit_convert("J", "eV")) == 1 / 1.602176634e-19 + assert approx(unit_convert("J", "eV", {"^": 2})) == (1 / 1.602176634e-19) ** 2 + assert approx(unit_convert("J", "eV", {"/": 2})) == (1 / 1.602176634e-19) / 2 + assert approx(unit_convert("J", "eV", {"*": 2})) == (1 / 1.602176634e-19) * 2 def test_default(): - assert unit_default('mass') == 'amu' - assert unit_default('energy') == 'eV' - assert unit_default('force') == 'eV/Ang' + assert unit_default("mass") == "amu" + assert unit_default("energy") == "eV" + assert unit_default("force") == "eV/Ang" def test_group_f1(): with pytest.raises(ValueError): - unit_group('not-existing') + unit_group("not-existing") def test_default_f1(): with pytest.raises(ValueError): - unit_default('not-existing') + unit_default("not-existing") def test_unit_convert_f1(): with pytest.raises(ValueError): - unit_convert('eV', 'megaerg') + unit_convert("eV", "megaerg") def test_unit_convert_f2(): with pytest.raises(ValueError): - unit_convert('eV', 'kg') + unit_convert("eV", "kg") diff --git a/src/sisl/utils/_arrays.py b/src/sisl/utils/_arrays.py index f2b8ce1157..ad081b0b4a 100644 --- a/src/sisl/utils/_arrays.py +++ b/src/sisl/utils/_arrays.py @@ -8,7 +8,7 @@ def batched_indices(ref, y, axis=-1, atol=1e-8, batch_size=200, diff_func=None): - """ Locate `x` in `ref` by examining ``np.abs(diff_func(ref - y)) <= atol`` + """Locate `x` in `ref` by examining ``np.abs(diff_func(ref - y)) <= atol`` This method is necessary for very large groups of data since the ``ref-y`` calls will use broad-casting to create very large memory chunks. @@ -24,7 +24,7 @@ def batched_indices(ref, y, axis=-1, atol=1e-8, batch_size=200, diff_func=None): ref : array_like reference array where we wish to locate the indices of `y` in y : array_like of 1D or 2D - array to locate in `ref`. For 2D arrays and `axis` not None, + array to locate in `ref`. For 2D arrays and `axis` not None, axis : int or None, optional which axis to do a logical reduction along, if `None` it means that they are 1D arrays and no axis will be reduced, i.e. same as ``ref.ravel() - y.reshape(-1, 1)`` @@ -56,8 +56,9 @@ def batched_indices(ref, y, axis=-1, atol=1e-8, batch_size=200, diff_func=None): n = max(1, n) if diff_func is None: + def diff_func(d): - """ Do nothing """ + """Do nothing""" return d def yield_cs(n, size): @@ -74,10 +75,14 @@ def yield_cs(n, size): # determine the batch size if axis is None: if atol.size != 1: - raise ValueError(f"batched_indices: for 1D comparisons atol can only be a single number.") + raise ValueError( + f"batched_indices: for 1D comparisons atol can only be a single number." + ) if y.ndim != 1: - raise ValueError(f"batched_indices: axis is None and y.ndim != 1 ({y.ndim}). For ravel comparisons the " - "dimensionality of y must be 1D.") + raise ValueError( + f"batched_indices: axis is None and y.ndim != 1 ({y.ndim}). For ravel comparisons the " + "dimensionality of y must be 1D." + ) # a 1D array comparison # we might as well ravel y (here to ensure we do not @@ -100,11 +105,15 @@ def yield_cs(n, size): # y must have 2 dimensions (or 1 with the same size as ref.shape[axis]) if y.ndim == 1: if y.size != ref.shape[axis]: - raise ValueError(f"batched_indices: when y is a single value it must have same length as ref.shape[axis]") + raise ValueError( + f"batched_indices: when y is a single value it must have same length as ref.shape[axis]" + ) y = y.reshape(1, -1) elif y.ndim == 2: if y.shape[1] != ref.shape[axis]: - raise ValueError(f"batched_indices: the comparison axis of y (y[0, :]) should have the same length as ref.shape[axis]") + raise ValueError( + f"batched_indices: the comparison axis of y (y[0, :]) should have the same length as ref.shape[axis]" + ) else: raise ValueError(f"batched_indices: y should be either 1D or 2D") @@ -118,14 +127,17 @@ def yield_cs(n, size): if atol.size > 1: if atol.size != ref.shape[axis]: - raise ValueError(f"batched_indices: atol size does not match the axis {axis} for ref argument.") + raise ValueError( + f"batched_indices: atol size does not match the axis {axis} for ref argument." + ) atol = np.expand_dims(atol.ravel(), tuple(range(ref.ndim - 1))) atol = np.moveaxis(atol, -1, axis) # b-cast size is for idx in yield_cs(n, y.shape[-1]): idx = np.logical_and.reduce( - np.abs(diff_func(ref - y[..., idx])) <= atol, axis=axis).nonzero()[:-1] + np.abs(diff_func(ref - y[..., idx])) <= atol, axis=axis + ).nonzero()[:-1] indices.append(idx) # concatenate each indices array diff --git a/src/sisl/utils/_sisl_cmd.py b/src/sisl/utils/_sisl_cmd.py index 2305076192..2a5b005b5b 100644 --- a/src/sisl/utils/_sisl_cmd.py +++ b/src/sisl/utils/_sisl_cmd.py @@ -11,7 +11,7 @@ def argparse_patch(parser): - """ Patch the argparse module such that one may process the Namespace in subparsers + """Patch the argparse module such that one may process the Namespace in subparsers This patch have been created by: paul.j3 (http://bugs.python.org/file44363/issue27859test.py) @@ -22,8 +22,8 @@ def argparse_patch(parser): parser : ArgumentParser parser to be patched """ - class MySubParsersAction(argparse._SubParsersAction): + class MySubParsersAction(argparse._SubParsersAction): def __call__(self, parser, namespace, values, option_string=None): parser_name = values[0] arg_strings = values[1:] @@ -36,9 +36,11 @@ def __call__(self, parser, namespace, values, option_string=None): try: parser = self._name_parser_map[parser_name] except KeyError: - args = {'parser_name': parser_name, - 'choices': ', '.join(self._name_parser_map)} - msg = ('unknown parser %(parser_name)r (choices: %(choices)s)') % args + args = { + "parser_name": parser_name, + "choices": ", ".join(self._name_parser_map), + } + msg = ("unknown parser %(parser_name)r (choices: %(choices)s)") % args raise argparse.ArgumentError(self, msg) # parse all the remaining options into the namespace @@ -50,14 +52,15 @@ def __call__(self, parser, namespace, values, option_string=None): namespace, arg_strings = parser.parse_known_args(arg_strings, namespace) ## ORIGINAL - #subnamespace, arg_strings = parser.parse_known_args(arg_strings, None) - #for key, value in vars(subnamespace).items(): + # subnamespace, arg_strings = parser.parse_known_args(arg_strings, None) + # for key, value in vars(subnamespace).items(): # setattr(namespace, key, value) if arg_strings: vars(namespace).setdefault(argparse._UNRECOGNIZED_ARGS_ATTR, []) getattr(namespace, argparse._UNRECOGNIZED_ARGS_ATTR).extend(arg_strings) - parser.register('action', 'parsers', MySubParsersAction) + + parser.register("action", "parsers", MySubParsersAction) def sisl_cmd(argv=None, sile=None): @@ -79,7 +82,7 @@ def sisl_cmd(argv=None, sile=None): elif len(sys.argv) == 1: # no arguments # fake a help - argv = ['--help'] + argv = ["--help"] else: argv = sys.argv[1:] @@ -92,9 +95,12 @@ def sisl_cmd(argv=None, sile=None): # Ensure that the arguments have pre-pended spaces argv = cmd.argv_negative_fix(argv) - p = argparse.ArgumentParser(exe, - formatter_class=argparse.RawDescriptionHelpFormatter, - description=description, conflict_handler='resolve') + p = argparse.ArgumentParser( + exe, + formatter_class=argparse.RawDescriptionHelpFormatter, + description=description, + conflict_handler="resolve", + ) # Add default sisl version stuff cmd.add_sisl_version_cite_arg(p) @@ -108,8 +114,8 @@ def sisl_cmd(argv=None, sile=None): # Now the arguments should have been populated # and we will sort out if the input options # is only a help option. - if not hasattr(ns, '_input_file'): - bypassed_args = ['--help', '-h', '--version', '--cite'] + if not hasattr(ns, "_input_file"): + bypassed_args = ["--help", "-h", "--version", "--cite"] # Then there are no input files... # It is difficult to create an adaptable script # with no adaptee... ;) @@ -118,7 +124,7 @@ def sisl_cmd(argv=None, sile=None): found = found or arg in argv if not found: # Re-create the argument parser with the help description - argv = ['--help'] + argv = ["--help"] # We are good to go!!! p.parse_args(argv, namespace=ns) diff --git a/src/sisl/utils/cmd.py b/src/sisl/utils/cmd.py index 47c57ecfa3..4ff8089bac 100644 --- a/src/sisl/utils/cmd.py +++ b/src/sisl/utils/cmd.py @@ -5,17 +5,17 @@ from sisl.utils.ranges import strmap, strseq -__all__ = ['argv_negative_fix', 'default_namespace'] -__all__ += ['collect_input', 'collect_arguments'] -__all__ += ['add_sisl_version_cite_arg'] -__all__ += ['default_ArgumentParser'] -__all__ += ['collect_action', 'run_collect_action'] -__all__ += ['run_actions'] -__all__ += ['add_action'] +__all__ = ["argv_negative_fix", "default_namespace"] +__all__ += ["collect_input", "collect_arguments"] +__all__ += ["add_sisl_version_cite_arg"] +__all__ += ["default_ArgumentParser"] +__all__ += ["collect_action", "run_collect_action"] +__all__ += ["run_actions"] +__all__ += ["add_action"] def argv_negative_fix(argv): - """ Fixes `argv` list by adding a space for input that may be float's + """Fixes `argv` list by adding a space for input that may be float's This function tries to prevent ``'-<>'`` being captured by `argparse`. @@ -32,20 +32,22 @@ def argv_negative_fix(argv): except Exception: rgv.append(a) else: - rgv.append(' ' + a) + rgv.append(" " + a) return rgv def default_namespace(**kwargs): - """ Ensure the namespace can be used to collect and run the actions + """Ensure the namespace can be used to collect and run the actions Parameters ---------- **kwargs : dict the dictionary keys added to the namespace object. """ + class CustomNamespace: pass + namespace = CustomNamespace() namespace._actions_run = False namespace._actions = [] @@ -55,7 +57,7 @@ class CustomNamespace: def add_action(namespace, action, args, kwargs): - """ Add an action to the list of actions to be runned + """Add an action to the list of actions to be runned Parameters ---------- @@ -72,7 +74,7 @@ def add_action(namespace, action, args, kwargs): def collect_input(argv): - """ Function for returning the input file + """Function for returning the input file This simply creates a shortcut input file and returns it. @@ -83,9 +85,9 @@ def collect_input(argv): arguments passed to an `argparse.ArgumentParser` """ # Grap input-file - p = argparse.ArgumentParser('Parser for input file', add_help=False) + p = argparse.ArgumentParser("Parser for input file", add_help=False) # Now add the input and output file - p.add_argument('input_file', nargs='?', default=None) + p.add_argument("input_file", nargs="?", default=None) # Retrieve the input file # (return the remaining options) args, argv = p.parse_known_args(argv) @@ -94,7 +96,7 @@ def collect_input(argv): def add_sisl_version_cite_arg(parser): - """ Add a sisl version and citation argument to the ArgumentParser for printing (to stdout) the used sisl version + """Add a sisl version and citation argument to the ArgumentParser for printing (to stdout) the used sisl version Parameters ---------- @@ -108,20 +110,28 @@ def add_sisl_version_cite_arg(parser): class PrintVersion(argparse.Action): def __call__(self, parser, ns, values, option_string=None): print(f"sisl: {__version__}") - group.add_argument('--version', nargs=0, action=PrintVersion, - help=f'Show detailed sisl version information (v{__version__})') + + group.add_argument( + "--version", + nargs=0, + action=PrintVersion, + help=f"Show detailed sisl version information (v{__version__})", + ) class PrintCite(argparse.Action): def __call__(self, parser, ns, values, option_string=None): print(f"BibTeX:\n{__bibtex__}") - group.add_argument('--cite', nargs=0, action=PrintCite, - help='Show the citation required when using sisl') + group.add_argument( + "--cite", + nargs=0, + action=PrintCite, + help="Show the citation required when using sisl", + ) -def collect_arguments(argv, input=False, - argumentparser=None, - namespace=None): - """ Function for returning the actual arguments depending on the input options. + +def collect_arguments(argv, input=False, argumentparser=None, namespace=None): + """Function for returning the actual arguments depending on the input options. This function will create a fake `argparse.ArgumentParser` which then will pass through the input figuring out which options @@ -154,8 +164,8 @@ def collect_arguments(argv, input=False, input_file = None # Grap output-file - p = argparse.ArgumentParser('Parser for output file', add_help=False) - p.add_argument('--out', '-o', nargs=1, default=None) + p = argparse.ArgumentParser("Parser for output file", add_help=False) + p.add_argument("--out", "-o", nargs=1, default=None) # Parse the passed args to sort out the input file and # the output file @@ -164,17 +174,20 @@ def collect_arguments(argv, input=False, if input_file is not None: try: obj = get_sile(input_file) - argumentparser, namespace = obj.ArgumentParser(argumentparser, namespace=namespace, - **obj._ArgumentParser_args_single()) + argumentparser, namespace = obj.ArgumentParser( + argumentparser, namespace=namespace, **obj._ArgumentParser_args_single() + ) # Be sure to add the input file - setattr(namespace, '_input_file', input_file) + setattr(namespace, "_input_file", input_file) except Exception as e: print(e) - raise ValueError(f"File: '{input_file}' cannot be found. Please supply a readable file!") + raise ValueError( + f"File: '{input_file}' cannot be found. Please supply a readable file!" + ) if args.out is not None: try: - obj = get_sile(args.out[0], mode='r') + obj = get_sile(args.out[0], mode="r") obj.ArgumentParser_out(argumentparser, namespace=namespace) except Exception: # Allowed pass due to pythonic reading @@ -188,6 +201,7 @@ def default_ArgumentParser(*A_args, **A_kwargs): Decorator for routines which takes a parser as argument and ensures that it is _not_ ``None``. """ + def default_AP(func): # This requires that the first argument # for the function is the parser with default=None @@ -199,7 +213,9 @@ def new_func(self, parser=None, *args, **kwargs): parser.description = A_kwargs["description"] return func(self, parser, *args, **kwargs) + return new_func + return default_AP @@ -216,6 +232,7 @@ def collect(self, *args, **kwargs): # Else we append the actions to be performed add_action(args[1], self, args, kwargs) return None + return collect @@ -231,6 +248,7 @@ def collect(self, *args, **kwargs): return func(self, *args, **kwargs) add_action(args[1], self, args, kwargs) return func(self, *args, **kwargs) + return collect @@ -249,4 +267,5 @@ def run(self, *args, **kwargs): args[1]._actions_run = False args[1]._actions = [] return func(self, *args, **kwargs) + return run diff --git a/src/sisl/utils/mathematics.py b/src/sisl/utils/mathematics.py index 3b2c7932fd..715fb47a61 100644 --- a/src/sisl/utils/mathematics.py +++ b/src/sisl/utils/mathematics.py @@ -27,7 +27,7 @@ def fnorm(array, axis=-1): - r""" Fast calculation of the norm of a vector + r"""Fast calculation of the norm of a vector Parameters ---------- @@ -40,7 +40,7 @@ def fnorm(array, axis=-1): def fnorm2(array, axis=-1): - r""" Fast calculation of the squared norm of a vector + r"""Fast calculation of the squared norm of a vector Parameters ---------- @@ -53,7 +53,7 @@ def fnorm2(array, axis=-1): def expand(vector, length): - r""" Expand `vector` by `length` such that the norm of the vector is increased by `length` + r"""Expand `vector` by `length` such that the norm of the vector is increased by `length` The expansion of the vector can be written as: @@ -75,7 +75,7 @@ def expand(vector, length): def orthogonalize(ref, vector): - r""" Ensure `vector` is orthogonal to `ref`, `vector` must *not* be parallel to `ref`. + r"""Ensure `vector` is orthogonal to `ref`, `vector` must *not* be parallel to `ref`. Enable an easy creation of a vector orthogonal to a reference vector. The length of the vector is not necessarily preserved (if they are not orthogonal). @@ -107,13 +107,15 @@ def orthogonalize(ref, vector): nr = fnorm(ref) vector = asarray(vector).ravel() d = dot(ref, vector) / nr - if abs(1. - abs(d) / fnorm(vector)) < 1e-7: - raise ValueError(f"orthogonalize: requires non-parallel vectors to perform an orthogonalization: ref.vector = {d}") + if abs(1.0 - abs(d) / fnorm(vector)) < 1e-7: + raise ValueError( + f"orthogonalize: requires non-parallel vectors to perform an orthogonalization: ref.vector = {d}" + ) return vector - ref * d / nr def spher2cart(r, theta, phi): - r""" Convert spherical coordinates to cartesian coordinates + r"""Convert spherical coordinates to cartesian coordinates Parameters ---------- @@ -140,7 +142,7 @@ def spher2cart(r, theta, phi): def cart2spher(r, theta=True, cos_phi=False, maxR=None): - r""" Transfer a vector to spherical coordinates with some possible differences + r"""Transfer a vector to spherical coordinates with some possible differences Parameters ---------- @@ -182,11 +184,11 @@ def cart2spher(r, theta=True, cos_phi=False, maxR=None): phi = r[:, 2] / rr else: phi = arccos(r[:, 2] / rr) - phi[rr == 0.] = 0. + phi[rr == 0.0] = 0.0 return rr, theta, phi rr = square(r).sum(-1) - idx = indices_le(rr, maxR ** 2) + idx = indices_le(rr, maxR**2) r = take(r, idx, 0) rr = sqrt(take(rr, idx)) if theta: @@ -199,12 +201,12 @@ def cart2spher(r, theta=True, cos_phi=False, maxR=None): phi = arccos(r[:, 2] / rr) # Typically there will be few rr==0. values, so no need to # create indices - phi[rr == 0.] = 0. + phi[rr == 0.0] = 0.0 return n, idx, rr, theta, phi def spherical_harm(m, l, theta, phi): - r""" Calculate the spherical harmonics using :math:`Y_l^m(\theta, \varphi)` with :math:`\mathbf R\to \{r, \theta, \varphi\}`. + r"""Calculate the spherical harmonics using :math:`Y_l^m(\theta, \varphi)` with :math:`\mathbf R\to \{r, \theta, \varphi\}`. .. math:: Y^m_l(\theta,\varphi) = (-1)^m\sqrt{\frac{2l+1}{4\pi} \frac{(l-m)!}{(l+m)!}} @@ -224,7 +226,7 @@ def spherical_harm(m, l, theta, phi): angle from :math:`z` axis (polar) """ # Probably same as: - #return (-1) ** m * ( (2*l+1)/(4*pi) * factorial(l-m) / factorial(l+m) ) ** 0.5 \ + # return (-1) ** m * ( (2*l+1)/(4*pi) * factorial(l-m) / factorial(l+m) ) ** 0.5 \ # * lpmv(m, l, cos(theta)) * exp(1j * m * phi) return sph_harm(m, l, theta, phi) * (-1) ** m @@ -308,7 +310,7 @@ def intersect_and_diff_sets(a, b): This saves a bit compared to doing np.delete() afterwards. """ aux = concatenate((a, b)) - aux_sort_indices = argsort(aux, kind='mergesort') + aux_sort_indices = argsort(aux, kind="mergesort") aux = aux[aux_sort_indices] # find elements that are the same in both arrays # after sorting we should have at most 2 same elements @@ -322,7 +324,7 @@ def intersect_and_diff_sets(a, b): no_buddy = nobuddy_lr[:-1] # no match left no_buddy &= nobuddy_lr[1:] # no match right - aonly = (aux_sort_indices < a.size) + aonly = aux_sort_indices < a.size bonly = ~aonly aonly &= no_buddy bonly &= no_buddy diff --git a/src/sisl/utils/misc.py b/src/sisl/utils/misc.py index a78f04c314..cebb03409d 100644 --- a/src/sisl/utils/misc.py +++ b/src/sisl/utils/misc.py @@ -17,13 +17,19 @@ # supported operators -_operators = {ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul, - ast.Div: op.truediv, ast.Pow: op.pow, ast.BitXor: op.xor, - ast.USub: op.neg} +_operators = { + ast.Add: op.add, + ast.Sub: op.sub, + ast.Mult: op.mul, + ast.Div: op.truediv, + ast.Pow: op.pow, + ast.BitXor: op.xor, + ast.USub: op.neg, +} def math_eval(expr): - """ Evaluate a mathematical expression using a safe evaluation method + """Evaluate a mathematical expression using a safe evaluation method Parameters ---------- @@ -43,18 +49,18 @@ def math_eval(expr): def _eval(node): - if isinstance(node, ast.Num): # + if isinstance(node, ast.Num): # return node.n - elif isinstance(node, ast.BinOp): # + elif isinstance(node, ast.BinOp): # return _operators[type(node.op)](_eval(node.left), _eval(node.right)) - elif isinstance(node, ast.UnaryOp): # e.g., -1 + elif isinstance(node, ast.UnaryOp): # e.g., -1 return _operators[type(node.op)](_eval(node.operand)) else: raise TypeError(node) def merge_instances(*args, **kwargs): - """ Merges an arbitrary number of instances together. + """Merges an arbitrary number of instances together. Parameters ---------- @@ -76,7 +82,7 @@ def merge_instances(*args, **kwargs): def iter_shape(shape): - """ Generator for iterating a shape by returning consecutive slices + """Generator for iterating a shape by returning consecutive slices Parameters ---------- @@ -100,7 +106,7 @@ def iter_shape(shape): [1, 0, 1] [1, 0, 2] """ - shape1 = [i-1 for i in shape] + shape1 = [i - 1 for i in shape] ns = len(shape) ns1 = ns - 1 # Create list for iterating @@ -117,11 +123,11 @@ def iter_shape(shape): if slc[i] >= shape1[i]: slc[i] = 0 if i > 0: - slc[i-1] += 1 + slc[i - 1] += 1 def str_spec(name): - """ Split into a tuple of name and specifier, delimited by ``{...}``. + """Split into a tuple of name and specifier, delimited by ``{...}``. Parameters ---------- @@ -149,7 +155,7 @@ def str_spec(name): # Transform a string to a Cartesian direction def direction(d, abc=None, xyz=None): - """ Index coordinate corresponding to the Cartesian coordinate system. + """Index coordinate corresponding to the Cartesian coordinate system. Parameters ---------- @@ -206,12 +212,14 @@ def direction(d, abc=None, xyz=None): else: if d in ("x", "y", "z", "a", "b", "c", "0", "1", "2"): return "xa0yb1zc2".index(d) // 3 - raise ValueError("direction: Input direction is not an integer, nor a string in 'xyz/abc/012'") + raise ValueError( + "direction: Input direction is not an integer, nor a string in 'xyz/abc/012'" + ) # Transform an input to an angle def angle(s, rad=True, in_rad=True): - """ Convert the input string to an angle, either radians or degrees. + """Convert the input string to an angle, either radians or degrees. Parameters ---------- @@ -271,7 +279,7 @@ def angle(s, rad=True, in_rad=True): if in_rad: Pi = pi else: - Pi = 180. + Pi = 180.0 s = (f"{Pi}").join(spi) @@ -280,9 +288,9 @@ def angle(s, rad=True, in_rad=True): # the expression ra = math_eval(s) if rad and not in_rad: - return ra / 180. * pi + return ra / 180.0 * pi if not rad and in_rad: - return ra / pi * 180. + return ra / pi * 180.0 # Both radians and in_radians are equivalent # so return as-is @@ -290,7 +298,7 @@ def angle(s, rad=True, in_rad=True): def allow_kwargs(*args): - """ Decoractor for forcing `func` to have the named arguments as listed in `args` + """Decoractor for forcing `func` to have the named arguments as listed in `args` This decorator merely removes any keyword argument from the called function which is in the list of `args` in case the function does not have the arguments @@ -301,6 +309,7 @@ def allow_kwargs(*args): *args : str required arguments in `func`, if already present nothing will be done. """ + def deco(func): if func is None: return None @@ -342,7 +351,7 @@ def dec_func(*args, **kwargs): def import_attr(attr_path): - """ Returns an attribute from a full module path + """Returns an attribute from a full module path Examples -------- @@ -361,7 +370,7 @@ def import_attr(attr_path): def lazy_import(name, package=None): - """ Lazily import a module or submodule + """Lazily import a module or submodule Parameters ---------- @@ -404,7 +413,7 @@ def lazy_import(name, package=None): class PropertyDict(dict): - """ Simple dictionary which may access items as properties as well """ + """Simple dictionary which may access items as properties as well""" def __getattr__(self, name): try: @@ -420,7 +429,6 @@ def __dir__(self): class NotNonePropertyDict(PropertyDict): - def __setitem__(self, key, value): if value is None: return diff --git a/src/sisl/utils/ranges.py b/src/sisl/utils/ranges.py index b02d059a1e..5365442ee0 100644 --- a/src/sisl/utils/ranges.py +++ b/src/sisl/utils/ranges.py @@ -9,7 +9,7 @@ # Function to change a string to a range of integers def strmap(func, s, start=None, end=None, sep="b"): - """ Parse a string as though it was a slice and map all entries using ``func``. + """Parse a string as though it was a slice and map all entries using ``func``. Parameters ---------- @@ -65,8 +65,8 @@ def strmap(func, s, start=None, end=None, sep="b"): i = i + 1 else: # there must be more [ than ] - commas[i] = commas[i] + "," + commas[i+1] - del commas[i+1] + commas[i] = commas[i] + "," + commas[i + 1] + del commas[i + 1] # Check the last input... i = len(commas) - 1 @@ -77,7 +77,6 @@ def strmap(func, s, start=None, end=None, sep="b"): # with collected brackets. l = [] for seg in commas: - # Split it in groups of reg-exps m = segment.findall(seg)[0] @@ -89,8 +88,9 @@ def strmap(func, s, start=None, end=None, sep="b"): elif len(m[2]) > 0: # this is: ..[..] - l.append((strseq(func, m[2], start, end), - strmap(func, m[3], start, end, sep))) + l.append( + (strseq(func, m[2], start, end), strmap(func, m[3], start, end, sep)) + ) elif len(m[4]) > 0: l.append(strseq(func, m[4], start, end)) @@ -99,7 +99,7 @@ def strmap(func, s, start=None, end=None, sep="b"): def strseq(cast, s, start=None, end=None): - """ Accept a string and return the casted tuples of content based on ranges. + """Accept a string and return the casted tuples of content based on ranges. Parameters ---------- @@ -154,14 +154,14 @@ def strseq(cast, s, start=None, end=None): def erange(start, step, end=None): - """ Returns the range with both ends includede """ + """Returns the range with both ends includede""" if end is None: return range(start, step + 1) return range(start, end + 1, step) def lstranges(lst, cast=erange, end=None): - """ Convert a `strmap` list into expanded ranges """ + """Convert a `strmap` list into expanded ranges""" l = [] # If an entry is a tuple, it means it is either # a range 0-1 == tuple(0, 1), or @@ -196,7 +196,7 @@ def lstranges(lst, cast=erange, end=None): def list2str(lst): - """ Convert a list of elements into a string of ranges + """Convert a list of elements into a string of ranges Examples -------- @@ -219,10 +219,10 @@ def list2str(lst): t += ln if ln == 1: rng += str(el) - #elif ln == 2: + # elif ln == 2: # rng += "{}, {}".format(str(el), str(el+ln-1)) else: - rng += "{}-{}".format(el, el+ln-1) + rng += "{}-{}".format(el, el + ln - 1) return rng @@ -235,7 +235,7 @@ def list2str(lst): # file[0-1] returns # file, [0,1] def fileindex(f, cast=int): - """ Parses a filename string into the filename and the indices. + """Parses a filename string into the filename and the indices. This range can be formatted like this: file[1,2,3-6] diff --git a/src/sisl/utils/tests/test_cmd.py b/src/sisl/utils/tests/test_cmd.py index e191f02816..932b8d8238 100644 --- a/src/sisl/utils/tests/test_cmd.py +++ b/src/sisl/utils/tests/test_cmd.py @@ -11,16 +11,16 @@ def test_default_namespace1(): - d = {'a': 1} + d = {"a": 1} dd = default_namespace(**d) - assert dd.a == d['a'] + assert dd.a == d["a"] def test_collect_input1(): - argv = ['test.xyz', '--stneohus stnaoeu', '-a', 'aote'] + argv = ["test.xyz", "--stneohus stnaoeu", "-a", "aote"] argv_out, in_file = collect_input(argv) - assert in_file == 'test.xyz' + assert in_file == "test.xyz" assert len(argv) == len(argv_out) + 1 @@ -36,13 +36,12 @@ def test_collect_arguments2(): def test_collect_arguments3(): with pytest.raises(ValueError): - collect_arguments(['this.file.never.exists'], input=True) + collect_arguments(["this.file.never.exists"], input=True) def test_decorators1(): - # Create a default argument parser - @default_ArgumentParser('SPBS', description='MY DEFAULT STUFF') + @default_ArgumentParser("SPBS", description="MY DEFAULT STUFF") def myArgParser(self, p=None, *args, **kwargs): return p @@ -50,7 +49,7 @@ def myArgParser(self, p=None, *args, **kwargs): assert "SPBS" in p.format_help() assert "MY DEFAULT STUFF" in p.format_help() - p = argparse.ArgumentParser(description='SECOND DEFAULT') + p = argparse.ArgumentParser(description="SECOND DEFAULT") p = myArgParser(None, p) assert "SPBS" not in p.format_help() assert "MY DEFAULT STUFF" in p.format_help() @@ -58,60 +57,59 @@ def myArgParser(self, p=None, *args, **kwargs): def test_decorators2(): - p = argparse.ArgumentParser() - ns = default_namespace(my_default='test') + ns = default_namespace(my_default="test") class Act1(argparse.Action): @run_collect_action def __call__(self, parser, ns, value, option_string=None): - setattr(ns, 'act1', value) + setattr(ns, "act1", value) class Act2(argparse.Action): @run_collect_action def __call__(self, parser, ns, value, option_string=None): assert ns.act1 is not None - setattr(ns, 'act2', value) + setattr(ns, "act2", value) class Act3(argparse.Action): @collect_action def __call__(self, parser, ns, value, option_string=None): - setattr(ns, 'act3', value) + setattr(ns, "act3", value) class Act4(argparse.Action): def __call__(self, parser, ns, value, option_string=None): with pytest.raises(AttributeError): assert ns.act3 is None - setattr(ns, 'act4', value) + setattr(ns, "act4", value) class Act5(argparse.Action): @run_actions def __call__(self, parser, ns, value, option_string=None): pass - p.add_argument('--a1', action=Act1) - p.add_argument('--a2', action=Act2) - p.add_argument('--a3', action=Act3) - p.add_argument('--a4', action=Act4) - p.add_argument('--a5', action=Act5, nargs=0) + p.add_argument("--a1", action=Act1) + p.add_argument("--a2", action=Act2) + p.add_argument("--a3", action=Act3) + p.add_argument("--a4", action=Act4) + p.add_argument("--a5", action=Act5, nargs=0) # Run arguments - argv = '--a1 v1 --a2 v2 --a3 v3 --a4 v4'.split() + argv = "--a1 v1 --a2 v2 --a3 v3 --a4 v4".split() args = p.parse_args(argv, namespace=ns) - assert args.my_default == 'test' + assert args.my_default == "test" - assert args.act1 == 'v1' - assert args.act2 == 'v2' + assert args.act1 == "v1" + assert args.act2 == "v2" with pytest.raises(AttributeError): assert args.act3 is None - assert args.act4 == 'v4' + assert args.act4 == "v4" - args = p.parse_args(argv + ['--a5'], namespace=ns) + args = p.parse_args(argv + ["--a5"], namespace=ns) - assert args.my_default == 'test' + assert args.my_default == "test" - assert args.act1 == 'v1' - assert args.act2 == 'v2' - assert args.act3 == 'v3' - assert args.act4 == 'v4' + assert args.act1 == "v1" + assert args.act2 == "v2" + assert args.act3 == "v3" + assert args.act4 == "v4" diff --git a/src/sisl/utils/tests/test_misc.py b/src/sisl/utils/tests/test_misc.py index 7df6853629..909d776b01 100644 --- a/src/sisl/utils/tests/test_misc.py +++ b/src/sisl/utils/tests/test_misc.py @@ -19,25 +19,25 @@ def test_direction_int(): def test_direction_str(): - assert direction('A') == 0 - assert direction('B') == 1 - assert direction('C') == 2 - assert direction('a') == 0 - assert direction('b') == 1 - assert direction('c') == 2 - assert direction('X') == 0 - assert direction('Y') == 1 - assert direction('Z') == 2 - assert direction('x') == 0 - assert direction('y') == 1 - assert direction('z') == 2 - assert direction('0') == 0 - assert direction('1') == 1 - assert direction('2') == 2 - assert direction(' 0') == 0 - assert direction(' 1 ') == 1 - assert direction(' 2 ') == 2 - assert np.allclose(direction(' 2 ', abc=np.diag([1, 2, 3])), [0, 0, 3]) + assert direction("A") == 0 + assert direction("B") == 1 + assert direction("C") == 2 + assert direction("a") == 0 + assert direction("b") == 1 + assert direction("c") == 2 + assert direction("X") == 0 + assert direction("Y") == 1 + assert direction("Z") == 2 + assert direction("x") == 0 + assert direction("y") == 1 + assert direction("z") == 2 + assert direction("0") == 0 + assert direction("1") == 1 + assert direction("2") == 2 + assert direction(" 0") == 0 + assert direction(" 1 ") == 1 + assert direction(" 2 ") == 2 + assert np.allclose(direction(" 2 ", abc=np.diag([1, 2, 3])), [0, 0, 3]) def test_direction_int_raises(): @@ -47,28 +47,28 @@ def test_direction_int_raises(): def test_direction_str_raises(): with pytest.raises(ValueError): - direction('aosetuh') + direction("aosetuh") def test_angle_r2r(): - assert pytest.approx(angle('2pi')) == 2*m.pi - assert pytest.approx(angle('2pi/2')) == m.pi - assert pytest.approx(angle('3pi/4')) == 3*m.pi/4 + assert pytest.approx(angle("2pi")) == 2 * m.pi + assert pytest.approx(angle("2pi/2")) == m.pi + assert pytest.approx(angle("3pi/4")) == 3 * m.pi / 4 - assert pytest.approx(angle('a2*180')) == 2*m.pi - assert pytest.approx(angle('2*180', in_rad=False)) == 2*m.pi - assert pytest.approx(angle('a2*180r')) == 2*m.pi + assert pytest.approx(angle("a2*180")) == 2 * m.pi + assert pytest.approx(angle("2*180", in_rad=False)) == 2 * m.pi + assert pytest.approx(angle("a2*180r")) == 2 * m.pi def test_angle_a2a(): - assert pytest.approx(angle('a2pia')) == 360 - assert pytest.approx(angle('a2pi/2a')) == 180 - assert pytest.approx(angle('a3pi/4a')) == 3*180./4 + assert pytest.approx(angle("a2pia")) == 360 + assert pytest.approx(angle("a2pi/2a")) == 180 + assert pytest.approx(angle("a3pi/4a")) == 3 * 180.0 / 4 - assert pytest.approx(angle('a2pia', True, True)) == 360 - assert pytest.approx(angle('a2pi/2a', True, False)) == 180 - assert pytest.approx(angle('a2pi/2a', False, True)) == 180 - assert pytest.approx(angle('a2pi/2a', False, False)) == 180 + assert pytest.approx(angle("a2pia", True, True)) == 360 + assert pytest.approx(angle("a2pi/2a", True, False)) == 180 + assert pytest.approx(angle("a2pi/2a", False, True)) == 180 + assert pytest.approx(angle("a2pi/2a", False, False)) == 180 def test_iter1(): @@ -91,30 +91,35 @@ def test_iter1(): def test_str_spec1(): - a = str_spec('foo') - assert a[0] == 'foo' + a = str_spec("foo") + assert a[0] == "foo" assert a[1] is None - a = str_spec('foo{bar}') - assert a[0] == 'foo' - assert a[1] == 'bar' + a = str_spec("foo{bar}") + assert a[0] == "foo" + assert a[1] == "bar" def test_merge_instances1(): class A: pass + a = A() a.hello = 1 + class B: pass + b = B() b.hello = 2 b.foo = 2 + class C: pass + c = C() c.bar = 3 - d = merge_instances(a, b, c, name='TestClass') - assert d.__class__.__name__ == 'TestClass' + d = merge_instances(a, b, c, name="TestClass") + assert d.__class__.__name__ == "TestClass" assert d.hello == 2 assert d.foo == 2 assert d.bar == 3 diff --git a/src/sisl/utils/tests/test_ranges.py b/src/sisl/utils/tests/test_ranges.py index 423fb1f27c..41a106abdd 100644 --- a/src/sisl/utils/tests/test_ranges.py +++ b/src/sisl/utils/tests/test_ranges.py @@ -13,7 +13,6 @@ @pytest.mark.ranges class TestRanges: - def test_strseq_colon(self): ranges = strseq(int, "1:2:5") assert ranges == (1, 2, 5) @@ -41,101 +40,106 @@ def test_strseq_minus(self): assert ranges == (-4, None) def test_strmap1(self): - assert strmap(int, '1') == [1] - assert strmap(int, '') == [None] - assert strmap(int, '1-') == [(1, None)] - assert strmap(int, '-') == [(None, None)] - assert strmap(int, '-1') == [-1] - assert strmap(int, '-1', start=1) == [-1] - assert strmap(int, '-', start=1, end=2) == [(1, 2)] - assert strmap(int, '-1:2') == [(-1, 2)] - assert strmap(int, '1,2') == [1, 2] - assert strmap(int, '1,2[0,2-]') == [1, (2, [0, (2, None)])] - assert strmap(int, '1,2-[0,-2]') == [1, ((2, None), [0, -2])] - assert strmap(int, '1,2[0,2]') == [1, (2, [0, 2])] - assert strmap(int, '1,2-3[0,2]') == [1, ((2, 3), [0, 2])] - assert strmap(int, '1,[2,3][0,2]') == [1, (2, [0, 2]), (3, [0, 2])] - assert strmap(int, '[82][10]') == [(82, [10])] - assert strmap(int, '[82,83][10]') == [(82, [10]), (83, [10])] - assert strmap(int, '[82,83][10-13]') == [(82, [(10, 13)]), (83, [(10, 13)])] + assert strmap(int, "1") == [1] + assert strmap(int, "") == [None] + assert strmap(int, "1-") == [(1, None)] + assert strmap(int, "-") == [(None, None)] + assert strmap(int, "-1") == [-1] + assert strmap(int, "-1", start=1) == [-1] + assert strmap(int, "-", start=1, end=2) == [(1, 2)] + assert strmap(int, "-1:2") == [(-1, 2)] + assert strmap(int, "1,2") == [1, 2] + assert strmap(int, "1,2[0,2-]") == [1, (2, [0, (2, None)])] + assert strmap(int, "1,2-[0,-2]") == [1, ((2, None), [0, -2])] + assert strmap(int, "1,2[0,2]") == [1, (2, [0, 2])] + assert strmap(int, "1,2-3[0,2]") == [1, ((2, 3), [0, 2])] + assert strmap(int, "1,[2,3][0,2]") == [1, (2, [0, 2]), (3, [0, 2])] + assert strmap(int, "[82][10]") == [(82, [10])] + assert strmap(int, "[82,83][10]") == [(82, [10]), (83, [10])] + assert strmap(int, "[82,83][10-13]") == [(82, [(10, 13)]), (83, [(10, 13)])] def test_strmap2(self): with pytest.raises(ValueError): - strmap(int, '1', sep='*') + strmap(int, "1", sep="*") def test_strmap3(self): - sm = partial(strmap, sep='c') - assert sm(int, '1') == [1] - assert sm(int, '1,2') == [1, 2] - assert sm(int, '1,2{0,2}') == [1, (2, [0, 2])] - assert sm(int, '1,2-3{0,2}') == [1, ((2, 3), [0, 2])] - assert sm(int, '1,{2,3}{0,2}') == [1, (2, [0, 2]), (3, [0, 2])] - assert sm(int, '{82}{10}') == [(82, [10])] - assert sm(int, '{82,83}{10}') == [(82, [10]), (83, [10])] - assert sm(int, '{82,83}{10-13}') == [(82, [(10, 13)]), (83, [(10, 13)])] + sm = partial(strmap, sep="c") + assert sm(int, "1") == [1] + assert sm(int, "1,2") == [1, 2] + assert sm(int, "1,2{0,2}") == [1, (2, [0, 2])] + assert sm(int, "1,2-3{0,2}") == [1, ((2, 3), [0, 2])] + assert sm(int, "1,{2,3}{0,2}") == [1, (2, [0, 2]), (3, [0, 2])] + assert sm(int, "{82}{10}") == [(82, [10])] + assert sm(int, "{82,83}{10}") == [(82, [10]), (83, [10])] + assert sm(int, "{82,83}{10-13}") == [(82, [(10, 13)]), (83, [(10, 13)])] def test_strmap4(self): with pytest.raises(ValueError): - strmap(int, '1[oestuh]]') + strmap(int, "1[oestuh]]") def test_strmap5(self): - r = strmap(int, '1-', end=5) + r = strmap(int, "1-", end=5) assert r == [(1, 5)] - r = strmap(int, '1-', start=0, end=5) + r = strmap(int, "1-", start=0, end=5) assert r == [(1, 5)] - r = strmap(int, '-4', start=0, end=5) + r = strmap(int, "-4", start=0, end=5) assert r == [1] - r = strmap(int, '-', start=0, end=5) + r = strmap(int, "-", start=0, end=5) assert r == [(0, 5)] def test_lstranges1(self): - ranges = strmap(int, '1,2-3[0,2]') + ranges = strmap(int, "1,2-3[0,2]") assert lstranges(ranges) == [1, [2, [0, 2]], [3, [0, 2]]] - ranges = strmap(int, '1,2-4[0-4,2],6[1-3],9-10') - assert lstranges(ranges) == [1, - [2, [0, 1, 2, 3, 4, 2]], - [3, [0, 1, 2, 3, 4, 2]], - [4, [0, 1, 2, 3, 4, 2]], - [6, [1, 2, 3]], - 9, 10] - ranges = strmap(int, '1,[2,4][0-4,2],6[1-3],9-10') - assert lstranges(ranges) == [1, - [2, [0, 1, 2, 3, 4, 2]], - [4, [0, 1, 2, 3, 4, 2]], - [6, [1, 2, 3]], - 9, 10] - ranges = strmap(int, '1,[2,4,6-7][0,3-4,2],6[1-3],9-10') - assert lstranges(ranges) == [1, - [2, [0, 3, 4, 2]], - [4, [0, 3, 4, 2]], - [6, [0, 3, 4, 2]], - [7, [0, 3, 4, 2]], - [6, [1, 2, 3]], - 9, 10] - ranges = strmap(int, '[82,83][10-13]') - assert lstranges(ranges) == [[82, [10, 11, 12, 13]], - [83, [10, 11, 12, 13]]] - ranges = strmap(int, ' [82,85][3]') - assert lstranges(ranges) == [[82, [3]], - [85, [3]]] - ranges = strmap(int, '81,[82,85][3]') - assert lstranges(ranges) == [81, - [82, [3]], - [85, [3]]] + ranges = strmap(int, "1,2-4[0-4,2],6[1-3],9-10") + assert lstranges(ranges) == [ + 1, + [2, [0, 1, 2, 3, 4, 2]], + [3, [0, 1, 2, 3, 4, 2]], + [4, [0, 1, 2, 3, 4, 2]], + [6, [1, 2, 3]], + 9, + 10, + ] + ranges = strmap(int, "1,[2,4][0-4,2],6[1-3],9-10") + assert lstranges(ranges) == [ + 1, + [2, [0, 1, 2, 3, 4, 2]], + [4, [0, 1, 2, 3, 4, 2]], + [6, [1, 2, 3]], + 9, + 10, + ] + ranges = strmap(int, "1,[2,4,6-7][0,3-4,2],6[1-3],9-10") + assert lstranges(ranges) == [ + 1, + [2, [0, 3, 4, 2]], + [4, [0, 3, 4, 2]], + [6, [0, 3, 4, 2]], + [7, [0, 3, 4, 2]], + [6, [1, 2, 3]], + 9, + 10, + ] + ranges = strmap(int, "[82,83][10-13]") + assert lstranges(ranges) == [[82, [10, 11, 12, 13]], [83, [10, 11, 12, 13]]] + ranges = strmap(int, " [82,85][3]") + assert lstranges(ranges) == [[82, [3]], [85, [3]]] + ranges = strmap(int, "81,[82,85][3]") + assert lstranges(ranges) == [81, [82, [3]], [85, [3]]] def test_lstranges2(self): - ranges = strmap(int, '1:2:5') + ranges = strmap(int, "1:2:5") assert lstranges(ranges) == [1, 3, 5] - ranges = strmap(int, '1-2-5') + ranges = strmap(int, "1-2-5") assert lstranges(ranges) == [1, 3, 5] def test_fileindex1(self): - fname = 'hello[1]' - assert fileindex('hehlo')[1] is None + fname = "hello[1]" + assert fileindex("hehlo")[1] is None assert fileindex(fname)[1] == 1 - assert fileindex('hehlo[1,2]')[1] == [1, 2] - assert fileindex('hehlo[1-2]')[1] == [1, 2] - assert fileindex('hehlo[1[1],2]')[1] == [[1, [1]], 2] + assert fileindex("hehlo[1,2]")[1] == [1, 2] + assert fileindex("hehlo[1-2]")[1] == [1, 2] + assert fileindex("hehlo[1[1],2]")[1] == [[1, [1]], 2] def test_list2str(self): a = list2str([2, 4, 5, 6]) diff --git a/src/sisl/viz/__init__.py b/src/sisl/viz/__init__.py index 81b391cbc3..ba6abe83d1 100644 --- a/src/sisl/viz/__init__.py +++ b/src/sisl/viz/__init__.py @@ -16,9 +16,12 @@ except Exception: _nprocs = 1 -register_environ_variable("SISL_VIZ_NUM_PROCS", min(1, _nprocs), - description="Maximum number of processors used for parallel plotting", - process=int) +register_environ_variable( + "SISL_VIZ_NUM_PROCS", + min(1, _nprocs), + description="Maximum number of processors used for parallel plotting", + process=int, +) from . import _xarray_accessor from ._plotables import register_plotable diff --git a/src/sisl/viz/_plotables.py b/src/sisl/viz/_plotables.py index 3069acc724..9c9913f136 100644 --- a/src/sisl/viz/_plotables.py +++ b/src/sisl/viz/_plotables.py @@ -16,15 +16,17 @@ class ClassPlotHandler(ClassDispatcher): """Handles all plotting possibilities for a class""" - def __init__(self, cls, *args, inherited_handlers = (), **kwargs): + def __init__(self, cls, *args, inherited_handlers=(), **kwargs): self._cls = cls if not "instance_dispatcher" in kwargs: kwargs["instance_dispatcher"] = ObjectPlotHandler kwargs["type_dispatcher"] = None super().__init__(*args, inherited_handlers=inherited_handlers, **kwargs) - self._dispatchs = ChainMap(self._dispatchs, *[handler._dispatchs for handler in inherited_handlers]) - + self._dispatchs = ChainMap( + self._dispatchs, *[handler._dispatchs for handler in inherited_handlers] + ) + def set_default(self, key: str): """Sets the default plotting function for the class.""" if key not in self._dispatchs: @@ -32,7 +34,6 @@ def set_default(self, key: str): self._default = key - class ObjectPlotHandler(ObjectDispatcher): """Handles all plotting possibilities for an object.""" @@ -49,7 +50,9 @@ def __call__(self, *args, **kwargs): """If the plot handler is called, we will run the default plotting function unless the keyword method has been passed.""" if self._default is None: - raise TypeError(f"No default plotting function has been defined for {self._obj.__class__.__name__}.") + raise TypeError( + f"No default plotting function has been defined for {self._obj.__class__.__name__}." + ) return getattr(self, self._default)(*args, **kwargs) @@ -74,8 +77,12 @@ def create_plot_dispatch(function, name): """ return type( f"Plot{name.capitalize()}Dispatch", - (PlotDispatch, ), - {"_plot": staticmethod(function), "__doc__": function.__doc__, "__signature__": inspect.signature(function)} + (PlotDispatch,), + { + "_plot": staticmethod(function), + "__doc__": function.__doc__, + "__signature__": inspect.signature(function), + }, ) @@ -95,7 +102,7 @@ def _get_plotting_func(plot_cls, setting_key): a function that accepts the object as first argument and then generates the plot. It sends the object to the appropiate setting key. The rest works exactly the same as - calling the plot class. I.e. you can provide all the extra settings/keywords that you want. + calling the plot class. I.e. you can provide all the extra settings/keywords that you want. """ def _plot(obj, *args, **kwargs): @@ -112,12 +119,23 @@ def _plot(obj, *args, **kwargs): # The signature will be the same as the plot class, but without the setting key, which # will be added by the _plot function - _plot.__signature__ = sig.replace(parameters=[p for p in sig.parameters.values() if p.name != setting_key]) + _plot.__signature__ = sig.replace( + parameters=[p for p in sig.parameters.values() if p.name != setting_key] + ) return _plot -def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=None, name=None, default=False, plot_handler_attr='plot', **kwargs): +def register_plotable( + plotable, + plot_cls=None, + setting_key=None, + plotting_func=None, + name=None, + default=False, + plot_handler_attr="plot", + **kwargs, +): """ Makes the sisl.viz module aware of which sisl objects can be plotted and how to do it. @@ -143,7 +161,7 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N E.g.: If name is "nicely", the plotting function will be registered under "obj.plot.nicely()" If not provided, the name of the function will be used - default: boolean, optional + default: boolean, optional whether this way of plotting the class should be the default one. plot_handler_attr: str, optional the attribute where the plot handler is or should be located in the class that you want to register. @@ -163,8 +181,10 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N # If it's the first time that the class is being registered, # let's give the class a plot handler - if not isinstance(plot_handler, ClassPlotHandler) or plot_handler._cls is not plotable: - + if ( + not isinstance(plot_handler, ClassPlotHandler) + or plot_handler._cls is not plotable + ): if isinstance(plot_handler, ClassPlotHandler): inherited_handlers = [plot_handler] else: @@ -174,7 +194,13 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N if not isinstance(plotable, type): plotable = type(plotable) - setattr(plotable, plot_handler_attr, ClassPlotHandler(plotable, plot_handler_attr, inherited_handlers=inherited_handlers)) + setattr( + plotable, + plot_handler_attr, + ClassPlotHandler( + plotable, plot_handler_attr, inherited_handlers=inherited_handlers + ), + ) plot_handler = getattr(plotable, plot_handler_attr) @@ -182,27 +208,35 @@ def register_plotable(plotable, plot_cls=None, setting_key=None, plotting_func=N # Register the function in the plot_handler plot_handler.register(name, plot_dispatch, default=default, **kwargs) + def register_data_source( - data_source_cls, plot_cls, setting_key, name=None, default: Sequence[Type] = [], plot_handler_attr='plot', + data_source_cls, + plot_cls, + setting_key, + name=None, + default: Sequence[Type] = [], + plot_handler_attr="plot", data_source_init_kwargs: dict = {}, - **kwargs + **kwargs, ): - # First register the data source itself register_plotable( - data_source_cls, plot_cls=plot_cls, setting_key=setting_key, - name=name, plot_handler_attr=plot_handler_attr, - **kwargs + data_source_cls, + plot_cls=plot_cls, + setting_key=setting_key, + name=name, + plot_handler_attr=plot_handler_attr, + **kwargs, ) # And then all its entry points plot_cls_params = { - name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) - for name, param in inspect.signature(plot_cls).parameters.items() if name != setting_key + name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) + for name, param in inspect.signature(plot_cls).parameters.items() + if name != setting_key } for plotable, cls_method in data_source_cls.new.dispatcher.registry.items(): - func = cls_method.__get__(None, data_source_cls) signature = inspect.signature(func) @@ -223,27 +257,33 @@ def register_data_source( for param in list(signature.parameters.values())[1:]: if param.kind == param.VAR_KEYWORD: data_var_kwarg = param.name - replaced_data_args[f'data_{param.name}'] = param.name - param = param.replace(name=f'data_{param.name}', kind=param.KEYWORD_ONLY, default={}) + replaced_data_args[f"data_{param.name}"] = param.name + param = param.replace( + name=f"data_{param.name}", kind=param.KEYWORD_ONLY, default={} + ) elif param.name in plot_cls_params: - replaced_data_args[f'data_{param.name}'] = param.name - param = param.replace(name=f'data_{param.name}') + replaced_data_args[f"data_{param.name}"] = param.name + param = param.replace(name=f"data_{param.name}") data_args.append(param.name) new_parameters.append(param) - new_parameters.extend(list(plot_cls_params.values())) - + new_parameters.extend(list(plot_cls_params.values())) + signature = signature.replace(parameters=new_parameters) params_info = { "data_args": data_args, "replaced_data_args": replaced_data_args, "data_var_kwarg": data_var_kwarg, - "plot_var_kwarg": new_parameters[-1].name if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD else None + "plot_var_kwarg": new_parameters[-1].name + if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD + else None, } - def _plot(obj, *args, __params_info=params_info, __signature=signature, **kwargs): + def _plot( + obj, *args, __params_info=params_info, __signature=signature, **kwargs + ): sig = __signature params_info = __params_info @@ -251,38 +291,42 @@ def _plot(obj, *args, __params_info=params_info, __signature=signature, **kwargs try: data_kwargs = {} - for k in params_info['data_args']: + for k in params_info["data_args"]: if k not in bound.arguments: continue - data_key = params_info['replaced_data_args'].get(k, k) - if params_info['data_var_kwarg'] == data_key: + data_key = params_info["replaced_data_args"].get(k, k) + if params_info["data_var_kwarg"] == data_key: data_kwargs.update(bound.arguments[k]) else: data_kwargs[data_key] = bound.arguments.pop(k) except Exception as e: - raise TypeError(f"Error while parsing arguments to create the {data_source_cls.__name__}") - + raise TypeError( + f"Error while parsing arguments to create the {data_source_cls.__name__}" + ) + for k, v in data_source_init_kwargs.items(): if k not in data_kwargs: data_kwargs[k] = v data = data_source_cls.new(obj, *args, **data_kwargs) - plot_kwargs = bound.arguments.pop(params_info['plot_var_kwarg'], {}) - + plot_kwargs = bound.arguments.pop(params_info["plot_var_kwarg"], {}) + return plot_cls(**{setting_key: data, **bound.arguments, **plot_kwargs}) - + _plot.__signature__ = signature doc = f"Read data into {data_source_cls.__name__} and create a {plot_cls.__name__} from it.\n\n" - doc += "This function accepts the arguments for creating both the data source and the plot. The following"\ - " arguments of the data source have been renamed so that they don't clash with the plot arguments:\n" + \ - '\n'.join( f' - {v} -> {k}' for k, v in replaced_data_args.items()) + \ - f"\n\nDocumentation for the {data_source_cls.__name__} creator ({func.__name__})"\ - f"\n=============\n{inspect.cleandoc(func.__doc__) if func.__doc__ is not None else None}"\ - f"\n\nDocumentation for {plot_cls.__name__}:"\ - f"\n=============\n{inspect.cleandoc(plot_cls.__doc__) if plot_cls.__doc__ is not None else None}" + doc += ( + "This function accepts the arguments for creating both the data source and the plot. The following" + " arguments of the data source have been renamed so that they don't clash with the plot arguments:\n" + + "\n".join(f" - {v} -> {k}" for k, v in replaced_data_args.items()) + + f"\n\nDocumentation for the {data_source_cls.__name__} creator ({func.__name__})" + f"\n=============\n{inspect.cleandoc(func.__doc__) if func.__doc__ is not None else None}" + f"\n\nDocumentation for {plot_cls.__name__}:" + f"\n=============\n{inspect.cleandoc(plot_cls.__doc__) if plot_cls.__doc__ is not None else None}" + ) _plot.__doc__ = doc @@ -290,21 +334,35 @@ def _plot(obj, *args, __params_info=params_info, __signature=signature, **kwargs this_default = plotable in default except: this_default = False - + try: register_plotable( - plotable, plot_cls=plot_cls, - plotting_func=_plot, name=name, default=this_default, plot_handler_attr=plot_handler_attr, - **kwargs + plotable, + plot_cls=plot_cls, + plotting_func=_plot, + name=name, + default=this_default, + plot_handler_attr=plot_handler_attr, + **kwargs, ) except TypeError: pass -def register_sile_method(sile_cls, method: str, plot_cls, setting_key, name=None, default=False, plot_handler_attr='plot', **kwargs): +def register_sile_method( + sile_cls, + method: str, + plot_cls, + setting_key, + name=None, + default=False, + plot_handler_attr="plot", + **kwargs, +): plot_cls_params = { - name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) - for name, param in inspect.signature(plot_cls).parameters.items() if name != setting_key + name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY) + for name, param in inspect.signature(plot_cls).parameters.items() + if name != setting_key } func = getattr(sile_cls, method) @@ -318,11 +376,13 @@ def register_sile_method(sile_cls, method: str, plot_cls, setting_key, name=None for param in list(signature.parameters.values())[1:]: if param.kind == param.VAR_KEYWORD: data_var_kwarg = param.name - replaced_data_args[param.name] = f'data_{param.name}' - param = param.replace(name=f'data_{param.name}', kind=param.KEYWORD_ONLY, default={}) + replaced_data_args[param.name] = f"data_{param.name}" + param = param.replace( + name=f"data_{param.name}", kind=param.KEYWORD_ONLY, default={} + ) elif param.name in plot_cls_params: - replaced_data_args[param.name] = f'data_{param.name}' - param = param.replace(name=f'data_{param.name}') + replaced_data_args[param.name] = f"data_{param.name}" + param = param.replace(name=f"data_{param.name}") data_args.append(param.name) new_parameters.append(param) @@ -333,50 +393,60 @@ def register_sile_method(sile_cls, method: str, plot_cls, setting_key, name=None "data_args": data_args, "replaced_data_args": replaced_data_args, "data_var_kwarg": data_var_kwarg, - "plot_var_kwarg": new_parameters[-1].name if len(new_parameters) > 0 and new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD else None - } - + "plot_var_kwarg": new_parameters[-1].name + if len(new_parameters) > 0 + and new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD + else None, + } + signature = signature.replace(parameters=new_parameters) def _plot(obj, *args, **kwargs): - bound = signature.bind_partial(**kwargs) try: data_kwargs = {} - for k in params_info['data_args']: + for k in params_info["data_args"]: if k not in bound.arguments: continue - data_key = params_info['replaced_data_args'].get(k, k) - if params_info['data_var_kwarg'] == data_key: + data_key = params_info["replaced_data_args"].get(k, k) + if params_info["data_var_kwarg"] == data_key: data_kwargs.update(bound.arguments[k]) else: data_kwargs[data_key] = bound.arguments.pop(k) except: - raise TypeError(f"Error while parsing arguments to create the call {method}") + raise TypeError( + f"Error while parsing arguments to create the call {method}" + ) data = func(obj, *args, **data_kwargs) - plot_kwargs = bound.arguments.pop(params_info['plot_var_kwarg'], {}) - + plot_kwargs = bound.arguments.pop(params_info["plot_var_kwarg"], {}) + return plot_cls(**{setting_key: data, **bound.arguments, **plot_kwargs}) - + _plot.__signature__ = signature doc = f"Calls {method} and creates a {plot_cls.__name__} from its output.\n\n" - doc += f"This function accepts the arguments both for calling {method} and creating the plot. The following"\ - f" arguments of {method} have been renamed so that they don't clash with the plot arguments:\n" + \ - '\n'.join( f' - {k} -> {v}' for k, v in replaced_data_args.items()) + \ - f"\n\nDocumentation for {method} "\ - f"\n=============\n{inspect.cleandoc(func.__doc__) if func.__doc__ is not None else None}"\ - f"\n\nDocumentation for {plot_cls.__name__}:"\ - f"\n=============\n{inspect.cleandoc(plot_cls.__doc__) if plot_cls.__doc__ is not None else None}" + doc += ( + f"This function accepts the arguments both for calling {method} and creating the plot. The following" + f" arguments of {method} have been renamed so that they don't clash with the plot arguments:\n" + + "\n".join(f" - {k} -> {v}" for k, v in replaced_data_args.items()) + + f"\n\nDocumentation for {method} " + f"\n=============\n{inspect.cleandoc(func.__doc__) if func.__doc__ is not None else None}" + f"\n\nDocumentation for {plot_cls.__name__}:" + f"\n=============\n{inspect.cleandoc(plot_cls.__doc__) if plot_cls.__doc__ is not None else None}" + ) _plot.__doc__ = doc - + register_plotable( - sile_cls, plot_cls=plot_cls, - plotting_func=_plot, name=name, default=default, plot_handler_attr=plot_handler_attr, - **kwargs - ) \ No newline at end of file + sile_cls, + plot_cls=plot_cls, + plotting_func=_plot, + name=name, + default=default, + plot_handler_attr=plot_handler_attr, + **kwargs, + ) diff --git a/src/sisl/viz/_plotables_register.py b/src/sisl/viz/_plotables_register.py index d66e150971..12d3a49e1e 100644 --- a/src/sisl/viz/_plotables_register.py +++ b/src/sisl/viz/_plotables_register.py @@ -8,6 +8,7 @@ """ import sisl import sisl.io.siesta as siesta + # import sisl.io.tbtrans as tbtrans from sisl.io.sile import BaseSile, get_siles @@ -26,13 +27,22 @@ # Register data sources # ----------------------------------------------------- -# This will automatically register as plotable everything that +# This will automatically register as plotable everything that # the data source can digest register_data_source(PDOSData, PdosPlot, "pdos_data", default=[siesta.pdosSileSiesta]) -register_data_source(BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta]) -register_data_source(BandsData, FatbandsPlot, "bands_data", data_source_init_kwargs={"extra_vars": ("norm2", )}) -register_data_source(EigenstateData, WavefunctionPlot, "eigenstate", default=[sisl.EigenstateElectron]) +register_data_source( + BandsData, BandsPlot, "bands_data", default=[siesta.bandsSileSiesta] +) +register_data_source( + BandsData, + FatbandsPlot, + "bands_data", + data_source_init_kwargs={"extra_vars": ("norm2",)}, +) +register_data_source( + EigenstateData, WavefunctionPlot, "eigenstate", default=[sisl.EigenstateElectron] +) # ----------------------------------------------------- # Register plotable siles @@ -41,22 +51,22 @@ register = register_plotable for GeomSile in get_siles(attrs=["read_geometry"]): - register_sile_method(GeomSile, "read_geometry", GeometryPlot, 'geometry') + register_sile_method(GeomSile, "read_geometry", GeometryPlot, "geometry") for GridSile in get_siles(attrs=["read_grid"]): - register_sile_method(GridSile, "read_grid", GridPlot, 'grid', default=True) + register_sile_method(GridSile, "read_grid", GridPlot, "grid", default=True) # # ----------------------------------------------------- # # Register plotable sisl objects # # ----------------------------------------------------- # # Geometry -register(sisl.Geometry, GeometryPlot, 'geometry', default=True) +register(sisl.Geometry, GeometryPlot, "geometry", default=True) # # Grid -register(sisl.Grid, GridPlot, 'grid', default=True) +register(sisl.Grid, GridPlot, "grid", default=True) # Brilloiun zone -register(sisl.BrillouinZone, SitesPlot, 'sites_obj') +register(sisl.BrillouinZone, SitesPlot, "sites_obj") sisl.BandStructure.plot.set_default("bands") diff --git a/src/sisl/viz/_presets.py b/src/sisl/viz/_presets.py index bbd04afda8..ea10bf44ef 100644 --- a/src/sisl/viz/_presets.py +++ b/src/sisl/viz/_presets.py @@ -6,13 +6,11 @@ __all__ = ["add_presets", "get_preset"] PRESETS = { - "dark": { "layout": {"template": "sisl_dark"}, "bands_color": "#ccc", - "bands_width": 2 + "bands_width": 2, }, - } diff --git a/src/sisl/viz/_single_dispatch.py b/src/sisl/viz/_single_dispatch.py index d74ee76528..59ae2d6b99 100644 --- a/src/sisl/viz/_single_dispatch.py +++ b/src/sisl/viz/_single_dispatch.py @@ -4,12 +4,11 @@ class singledispatchmethod(real_singledispatchmethod): def register(self, cls, method=None): - if hasattr(cls, '__func__'): - setattr(cls, '__annotations__', cls.__func__.__annotations__) + if hasattr(cls, "__func__"): + setattr(cls, "__annotations__", cls.__func__.__annotations__) return self.dispatcher.register(cls, func=method) - + def __get__(self, obj, cls=None): _method = super().__get__(obj, cls) _method.dispatcher = self.dispatcher return _method - \ No newline at end of file diff --git a/src/sisl/viz/_splot.py b/src/sisl/viz/_splot.py index 7bf5ec744b..6cf5f04856 100644 --- a/src/sisl/viz/_splot.py +++ b/src/sisl/viz/_splot.py @@ -31,35 +31,68 @@ def general_arguments(parser): parser: argparse.ArgumentParser the parser to which you want to add the arguments """ - parser.add_argument('--presets', '-p', type=str, nargs="*", required=False, - help=f'The names of the stored presets that you want to use for the settings. Current available presets: {get_avail_presets()}') + parser.add_argument( + "--presets", + "-p", + type=str, + nargs="*", + required=False, + help=f"The names of the stored presets that you want to use for the settings. Current available presets: {get_avail_presets()}", + ) # parser.add_argument('--template', '-t', type=str, required=False, # help=f"""The plotly layout template that you want to use. It is equivalent as passing a template to --layout. # Available templates: {list(plotly.io.templates.keys())}. Default: {plotly.io.templates.default}""") - parser.add_argument('--layout', '-l', type=ast.literal_eval, required=False, - help=f'A dict containing all the layout attributes that you want to pass to the plot.') + parser.add_argument( + "--layout", + "-l", + type=ast.literal_eval, + required=False, + help=f"A dict containing all the layout attributes that you want to pass to the plot.", + ) - parser.add_argument('--save', '-s', type=str, required=False, - help='The path where you want to save the plot. Note that you can add the extension .html to save to html.') + parser.add_argument( + "--save", + "-s", + type=str, + required=False, + help="The path where you want to save the plot. Note that you can add the extension .html to save to html.", + ) - parser.add_argument('--no-show', dest='show', action='store_false', - help="Pass this flag if you don't want the plot to be displayed.") + parser.add_argument( + "--no-show", + dest="show", + action="store_false", + help="Pass this flag if you don't want the plot to be displayed.", + ) - parser.add_argument('--editable', '-e', dest='editable', action='store_true', - help="Display the plot in editable mode, so that you can edit on-site the titles, axis ranges and more." + - " Keep in mind that the changes won't be saved, but you can take a picture of it with the toolbar") + parser.add_argument( + "--editable", + "-e", + dest="editable", + action="store_true", + help="Display the plot in editable mode, so that you can edit on-site the titles, axis ranges and more." + + " Keep in mind that the changes won't be saved, but you can take a picture of it with the toolbar", + ) - parser.add_argument('--drawable', '-d', dest='drawable', action='store_true', - help="Display the plot in drawable mode, which allows you to draw shapes and lines" + - " Keep in mind that the changes won't be saved, but you can take a picture of it with the toolbar") + parser.add_argument( + "--drawable", + "-d", + dest="drawable", + action="store_true", + help="Display the plot in drawable mode, which allows you to draw shapes and lines" + + " Keep in mind that the changes won't be saved, but you can take a picture of it with the toolbar", + ) - parser.add_argument('--shortcuts', '-sh', nargs="*", - help="The shortcuts to apply to the plot after it has been built. " + - "They should be passed as the sequence of keys that need to be pressed to trigger the shortcut"+ - "You can pass as many as you want. If the built plot is an animation, the shortcuts will be applied"+ - "to each plot separately" + parser.add_argument( + "--shortcuts", + "-sh", + nargs="*", + help="The shortcuts to apply to the plot after it has been built. " + + "They should be passed as the sequence of keys that need to be pressed to trigger the shortcut" + + "You can pass as many as you want. If the built plot is an animation, the shortcuts will be applied" + + "to each plot separately", ) @@ -67,19 +100,25 @@ def splot(): """ Command utility for plotting things fast from the terminal. """ - parser = argparse.ArgumentParser(prog='splot', - description="Command utility to plot files fast. This command allows great customability." + - "\n\nOnly you know how you like your plots. Therefore, a nice way to use this command is by " + - "using presets that you stored previously. Note that you can either use sisl's provided presets" + - f" or define your own presets. Sisl is looking for presets under the '{PRESETS_VARIABLE}' variable" + - f" defined in {PRESETS_FILE}. It should be a dict containing all your presets.", + parser = argparse.ArgumentParser( + prog="splot", + description="Command utility to plot files fast. This command allows great customability." + + "\n\nOnly you know how you like your plots. Therefore, a nice way to use this command is by " + + "using presets that you stored previously. Note that you can either use sisl's provided presets" + + f" or define your own presets. Sisl is looking for presets under the '{PRESETS_VARIABLE}' variable" + + f" defined in {PRESETS_FILE}. It should be a dict containing all your presets.", ) # Add default sisl version stuff cmd.add_sisl_version_cite_arg(parser) - parser.add_argument('--files', "-f", type=str, nargs="*", default=[], - help='The files that you want to plot. As many as you want.' + parser.add_argument( + "--files", + "-f", + type=str, + nargs="*", + default=[], + help="The files that you want to plot. As many as you want.", ) # Add some arguments that work for any plot @@ -88,17 +127,22 @@ def splot(): # Add arguments that correspond to the settings of the Plot class for param in Plot._parameters: if param.dtype is not None and not isinstance(param.dtype, str): - parser.add_argument(f'--{param.key}', type=param.parse, required=False, help=getattr(param, "help", "")) + parser.add_argument( + f"--{param.key}", + type=param.parse, + required=False, + help=getattr(param, "help", ""), + ) subparsers = parser.add_subparsers( - help="YOU DON'T NEED TO PASS A PLOT CLASS. You can provide a file (see the -f flag) and sisl will decide for you."+ - " However, if you want to avoid sisl automatic choice, you can use these subcommands to select a"+ - " plot class. By doing so, you will also get access to plot-specific settings. Try splot bands -h, for example."+ - " Note that you can also build your own plots that will be automatically available here." + - f" Sisl is looking to import plots defined in {PLOTS_FILE}."+ - "\n Also note that doing 'splot bands' with any extra arguments will search your current directory "+ - "for *.bands files to plot. The rest of plots will also do this.", - dest="plot_class" + help="YOU DON'T NEED TO PASS A PLOT CLASS. You can provide a file (see the -f flag) and sisl will decide for you." + + " However, if you want to avoid sisl automatic choice, you can use these subcommands to select a" + + " plot class. By doing so, you will also get access to plot-specific settings. Try splot bands -h, for example." + + " Note that you can also build your own plots that will be automatically available here." + + f" Sisl is looking to import plots defined in {PLOTS_FILE}." + + "\n Also note that doing 'splot bands' with any extra arguments will search your current directory " + + "for *.bands files to plot. The rest of plots will also do this.", + dest="plot_class", ) avail_plots = get_plot_classes() @@ -106,19 +150,30 @@ def splot(): # Generate all the subparsers (one for each type of plot) for PlotClass in avail_plots: doc = PlotClass.__doc__ or "" - specific_parser = subparsers.add_parser(PlotClass.suffix(), help=doc.split(".")[0]) + specific_parser = subparsers.add_parser( + PlotClass.suffix(), help=doc.split(".")[0] + ) if hasattr(PlotClass, "_default_animation"): - specific_parser.add_argument('--animated', '-ani', dest="animated", action="store_true", - help=f"If this flag is present, the default animation for {PlotClass.__name__} will be build"+ - " instead of a regular plot" + specific_parser.add_argument( + "--animated", + "-ani", + dest="animated", + action="store_true", + help=f"If this flag is present, the default animation for {PlotClass.__name__} will be build" + + " instead of a regular plot", ) general_arguments(specific_parser) for param in PlotClass._get_class_params()[0]: if param.dtype is not None and not isinstance(param.dtype, str): - specific_parser.add_argument(f'--{param.key}', type=param.parse, required=False, help=getattr(param, "help", "")) + specific_parser.add_argument( + f"--{param.key}", + type=param.parse, + required=False, + help=getattr(param, "help", ""), + ) args = parser.parse_args() @@ -140,10 +195,9 @@ def splot(): settings[param.key] = setting_value # If no settings were provided, we are going to try to guess - if not settings and hasattr(plot_class, '_registered_plotables'): + if not settings and hasattr(plot_class, "_registered_plotables"): siles = find_plotable_siles(depth=0) for SileClass, filepaths in siles.items(): - if SileClass in plot_class._registered_plotables: settings[plot_class._registered_plotables[SileClass]] = filepaths[0] break @@ -181,17 +235,18 @@ def splot(): # Show the plot if it was requested if getattr(args, "show", True): - # Extra configuration that the user requested for the display config = { - 'editable': args.editable, - 'modeBarButtonsToAdd': [ - 'drawline', - 'drawopenpath', - 'drawclosedpath', - 'drawcircle', - 'drawrect', - 'eraseshape' - ] if args.drawable else [] + "editable": args.editable, + "modeBarButtonsToAdd": [ + "drawline", + "drawopenpath", + "drawclosedpath", + "drawcircle", + "drawrect", + "eraseshape", + ] + if args.drawable + else [], } plot.show(config=config) diff --git a/src/sisl/viz/_xarray_accessor.py b/src/sisl/viz/_xarray_accessor.py index f543a0e922..0ef955ea7d 100644 --- a/src/sisl/viz/_xarray_accessor.py +++ b/src/sisl/viz/_xarray_accessor.py @@ -20,17 +20,21 @@ def _method(self, *args, **kwargs): return _method -def plot_xy(*args, backend: str ="plotly", **kwargs): +def plot_xy(*args, backend: str = "plotly", **kwargs): plot_actions = draw_xarray_xy(*args, **kwargs) return get_figure(plot_actions=plot_actions, backend=backend) + sig = inspect.signature(draw_xarray_xy) -plot_xy.__signature__ = sig.replace(parameters=[ - *sig.parameters.values(), - inspect.Parameter("backend", inspect.Parameter.KEYWORD_ONLY, default="plotly") -]) +plot_xy.__signature__ = sig.replace( + parameters=[ + *sig.parameters.values(), + inspect.Parameter("backend", inspect.Parameter.KEYWORD_ONLY, default="plotly"), + ] +) + @xr.register_dataarray_accessor("sisl") class SislAccessorDataArray: @@ -47,6 +51,7 @@ def __init__(self, xarray_obj): plot_xy = wrap_accessor_method(plot_xy) + @xr.register_dataset_accessor("sisl") class SislAccessorDataset: def __init__(self, xarray_obj): @@ -60,4 +65,4 @@ def __init__(self, xarray_obj): reduce_atoms = wrap_accessor_method(reduce_atom_data) - plot_xy = wrap_accessor_method(plot_xy) \ No newline at end of file + plot_xy = wrap_accessor_method(plot_xy) diff --git a/src/sisl/viz/data/bands.py b/src/sisl/viz/data/bands.py index ad261ffb48..b7c9fdd044 100644 --- a/src/sisl/viz/data/bands.py +++ b/src/sisl/viz/data/bands.py @@ -20,23 +20,32 @@ try: import pathos + _do_parallel_calc = True except: _do_parallel_calc = False try: from aiida import orm + Aiida_node = orm.Node AIIDA_AVAILABLE = True except ModuleNotFoundError: - class Aiida_node: pass + + class Aiida_node: + pass + AIIDA_AVAILABLE = False -class BandsData(XarrayData): - def sanity_check(self, - n_spin: Optional[int] = None, nk: Optional[int] = None, nbands: Optional[int] = None, - klabels: Optional[Sequence[str]] = None, kvals: Optional[Sequence[float]] = None +class BandsData(XarrayData): + def sanity_check( + self, + n_spin: Optional[int] = None, + nk: Optional[int] = None, + nbands: Optional[int] = None, + klabels: Optional[Sequence[str]] = None, + kvals: Optional[Sequence[float]] = None, ): """Check that the dataarray satisfies the requirements to be treated as PDOSData.""" super().sanity_check() @@ -44,29 +53,43 @@ def sanity_check(self, array = self._data for k in ("k", "band"): - assert k in array.dims, f"'{k}' dimension missing, existing dimensions: {array.dims}" + assert ( + k in array.dims + ), f"'{k}' dimension missing, existing dimensions: {array.dims}" - spin = array.attrs['spin'] + spin = array.attrs["spin"] assert isinstance(spin, Spin) if n_spin is not None: if n_spin == 1: - assert spin.is_unpolarized, f"Spin in the data is {spin}, but n_spin=1 was expected" + assert ( + spin.is_unpolarized + ), f"Spin in the data is {spin}, but n_spin=1 was expected" elif n_spin == 2: - assert spin.is_polarized, f"Spin in the data is {spin}, but n_spin=2 was expected" + assert ( + spin.is_polarized + ), f"Spin in the data is {spin}, but n_spin=2 was expected" elif n_spin == 4: - assert not spin.is_diagonal, f"Spin in the data is {spin}, but n_spin=4 was expected" + assert ( + not spin.is_diagonal + ), f"Spin in the data is {spin}, but n_spin=4 was expected" # Check if we have the correct number of spin channels if spin.is_polarized: - assert "spin" in array.dims, f"'spin' dimension missing for polarized spin, existing dimensions: {array.dims}" + assert ( + "spin" in array.dims + ), f"'spin' dimension missing for polarized spin, existing dimensions: {array.dims}" if n_spin is not None: assert len(array.spin) == n_spin else: - assert "spin" not in array.dims, f"'spin' dimension present for spin different than polarized, existing dimensions: {array.dims}" - assert "spin" not in array.coords, f"'spin' coordinate present for spin different than polarized, existing dimensions: {array.dims}" - - # Check shape of bands + assert ( + "spin" not in array.dims + ), f"'spin' dimension present for spin different than polarized, existing dimensions: {array.dims}" + assert ( + "spin" not in array.coords + ), f"'spin' coordinate present for spin different than polarized, existing dimensions: {array.dims}" + + # Check shape of bands if nk is not None: assert len(array.k) == nk if nbands is not None: @@ -78,62 +101,84 @@ def sanity_check(self, # Check if k ticks match the expected ones if klabels is not None: assert "axis" in array.k.attrs, "No axis specification for the k dimension." - assert "ticktext" in array.k.attrs['axis'], "No ticks were found for the k dimension" - assert tuple(array.k.attrs['axis']['ticktext']) == tuple(klabels), f"Expected labels {klabels} but found {array.k.attrs['axis']['ticktext']}" + assert ( + "ticktext" in array.k.attrs["axis"] + ), "No ticks were found for the k dimension" + assert tuple(array.k.attrs["axis"]["ticktext"]) == tuple( + klabels + ), f"Expected labels {klabels} but found {array.k.attrs['axis']['ticktext']}" if kvals is not None: assert "axis" in array.k.attrs, "No axis specification for the k dimension." - assert "tickvals" in array.k.attrs['axis'], "No ticks were found for the k dimension" - assert np.allclose(array.k.attrs['axis']['tickvals'], kvals), f"Expected label values {kvals} but found {array.k.attrs['axis']['tickvals']}" + assert ( + "tickvals" in array.k.attrs["axis"] + ), "No ticks were found for the k dimension" + assert np.allclose( + array.k.attrs["axis"]["tickvals"], kvals + ), f"Expected label values {kvals} but found {array.k.attrs['axis']['tickvals']}" @classmethod - def toy_example(cls, spin: Union[str, int, Spin] = "", n_states: int = 20, nk: int = 30, gap: Optional[float] = None): + def toy_example( + cls, + spin: Union[str, int, Spin] = "", + n_states: int = 20, + nk: int = 30, + gap: Optional[float] = None, + ): """Creates a toy example of a bands data array""" spin = Spin(spin) n_bands = n_states if spin.is_diagonal else n_states * 2 - + if spin.is_polarized: polynoms_shape = (2, n_bands) dims = ("spin", "k", "band") shift = np.tile(np.arange(0, n_bands), 2).reshape(2, -1) else: - polynoms_shape = (n_bands, ) + polynoms_shape = (n_bands,) dims = ("k", "band") shift = np.arange(0, n_bands) # Create some random coefficients for degree 2 polynomials that will be used to generate the bands random_polinomials = np.random.rand(*polynoms_shape, 3) - random_polinomials[..., 0] *= 10 # Bigger curvature - random_polinomials[..., :n_bands // 2, 0] *= -1 # Make the curvature negative below the gap - random_polinomials[..., 2] += shift # Shift each polynomial so that bands stack on top of each other + random_polinomials[..., 0] *= 10 # Bigger curvature + random_polinomials[ + ..., : n_bands // 2, 0 + ] *= -1 # Make the curvature negative below the gap + random_polinomials[ + ..., 2 + ] += shift # Shift each polynomial so that bands stack on top of each other # Compute bands x = np.linspace(0, 1, nk) - y = np.outer(x ** 2, random_polinomials[..., 0]) + np.outer(x, random_polinomials[..., 1]) + random_polinomials[..., 2].ravel() + y = ( + np.outer(x**2, random_polinomials[..., 0]) + + np.outer(x, random_polinomials[..., 1]) + + random_polinomials[..., 2].ravel() + ) y = y.reshape(nk, *polynoms_shape) if spin.is_polarized: # Make sure that the top of the valence band and bottom of the conduction band # are the same spin (to facilitate computation of the gap). - VB_spin = y[..., :n_bands // 2].argmin() // (nk * n_bands) - CB_spin = y[..., n_bands // 2:].argmax() // (nk * n_bands) + VB_spin = y[..., : n_bands // 2].argmin() // (nk * n_bands) + CB_spin = y[..., n_bands // 2 :].argmax() // (nk * n_bands) if VB_spin != CB_spin: - y[..., n_bands // 2:] = np.flip(y[..., n_bands // 2:], axis=0) + y[..., n_bands // 2 :] = np.flip(y[..., n_bands // 2 :], axis=0) y = y.transpose(1, 0, 2) # Compute gap limits - top_VB = y[..., :n_bands // 2 ].max() - bottom_CB = y[..., n_bands // 2:].min() + top_VB = y[..., : n_bands // 2].max() + bottom_CB = y[..., n_bands // 2 :].min() # Correct the gap if some specific value was requested generated_gap = bottom_CB - top_VB if gap is not None: - add_shift = (gap - generated_gap) - y[..., n_bands // 2:] += add_shift + add_shift = gap - generated_gap + y[..., n_bands // 2 :] += add_shift bottom_CB += add_shift # Compute fermi level @@ -154,22 +199,18 @@ def toy_example(cls, spin: Union[str, int, Spin] = "", n_states: int = 20, nk: i # Add spin moments if the spin is not diagonal if not spin.is_diagonal: spin_moments = np.random.rand(nk, n_bands, 3) * 2 - 1 - data['spin_moments'] = xr.DataArray( + data["spin_moments"] = xr.DataArray( spin_moments, - coords={ - "k": x, - "band": np.arange(0, n_bands), - "axis": ["x", "y", "z"] - }, - dims=("k", "band", "axis") + coords={"k": x, "band": np.arange(0, n_bands), "axis": ["x", "y", "z"]}, + dims=("k", "band", "axis"), ) # Add the spin class of the data - data.attrs['spin'] = spin + data.attrs["spin"] = spin # Inform of where to place the ticks data.k.attrs["axis"] = { - "tickvals": [0, x[-1]], + "tickvals": [0, x[-1]], "ticktext": ["Gamma", "X"], } @@ -179,19 +220,18 @@ def toy_example(cls, spin: Union[str, int, Spin] = "", n_states: int = 20, nk: i @classmethod def new(cls, bands_data): return cls(bands_data) - + @new.register @classmethod def from_dataset(cls, bands_data: xr.Dataset): - old_attrs = bands_data.attrs - + # Check if there's a spin attribute spin = old_attrs.get("spin", None) # If not, guess it if spin is None: - if 'spin' not in bands_data: + if "spin" not in bands_data: spin = Spin(Spin.UNPOLARIZED) else: spin = { @@ -203,35 +243,39 @@ def from_dataset(cls, bands_data: xr.Dataset): spin = Spin(spin) # Remove the spin coordinate if the data is not spin polarized - if 'spin' in bands_data and not spin.is_polarized: + if "spin" in bands_data and not spin.is_polarized: bands_data = bands_data.isel(spin=0).drop_vars("spin") if spin.is_polarized: spin_options = [0, 1] - bands_data['spin'] = ('spin', spin_options, bands_data.spin.attrs) + bands_data["spin"] = ("spin", spin_options, bands_data.spin.attrs) # elif not spin.is_diagonal: # spin_options = get_spin_options(spin) # bands_data['spin'] = ('spin', spin_options, bands_data.spin.attrs) # If the energy variable doesn't have units, set them as eV - if 'E' in bands_data and 'units' not in bands_data.E.attrs: - bands_data.E.attrs['units'] = 'eV' + if "E" in bands_data and "units" not in bands_data.E.attrs: + bands_data.E.attrs["units"] = "eV" # Same with the k coordinate, which we will assume are 1/Ang - if 'k' in bands_data and 'units' not in bands_data.k.attrs: - bands_data.k.attrs['units'] = '1/Ang' + if "k" in bands_data and "units" not in bands_data.k.attrs: + bands_data.k.attrs["units"] = "1/Ang" # If there are ticks, show the grid. - if 'axis' in bands_data.k.attrs and bands_data.k.attrs['axis'].get('ticktext') is not None: - bands_data.k.attrs['axis'] = {"showgrid": True, **bands_data.k.attrs.get('axis', {})} + if ( + "axis" in bands_data.k.attrs + and bands_data.k.attrs["axis"].get("ticktext") is not None + ): + bands_data.k.attrs["axis"] = { + "showgrid": True, + **bands_data.k.attrs.get("axis", {}), + } - bands_data.attrs = { - **old_attrs, "spin": spin - } + bands_data.attrs = {**old_attrs, "spin": spin} if "geometry" not in bands_data.attrs: if "parent" in bands_data.attrs: parent = bands_data.attrs["parent"] if hasattr(parent, "geometry"): - bands_data.attrs['geometry'] = parent.geometry + bands_data.attrs["geometry"] = parent.geometry return cls(bands_data) @@ -240,9 +284,9 @@ def from_dataset(cls, bands_data: xr.Dataset): def from_dataarray(cls, bands_data: xr.DataArray): bands_data_ds = xr.Dataset({"E": bands_data}) bands_data_ds.attrs.update(bands_data.attrs) - + return cls.new(bands_data_ds) - + @new.register @classmethod def from_path(cls, path: Path, *args, **kwargs): @@ -255,36 +299,40 @@ def from_string(cls, string: str, *args, **kwargs): """Assumes the string is a path to a file""" return cls.new(Path(string), *args, **kwargs) - @new.register @classmethod - def from_fdf(cls, fdf: fdfSileSiesta, bands_file: Union[str, bandsSileSiesta, None] = None): + def from_fdf( + cls, fdf: fdfSileSiesta, bands_file: Union[str, bandsSileSiesta, None] = None + ): """Gets the bands data from a SIESTA .bands file""" - bands_file = FileDataSIESTA(fdf=fdf, path=bands_file, cls=sisl.io.bandsSileSiesta) + bands_file = FileDataSIESTA( + fdf=fdf, path=bands_file, cls=sisl.io.bandsSileSiesta + ) assert isinstance(bands_file, bandsSileSiesta) return cls.new(bands_file) - + @new.register @classmethod def from_siesta_bands(cls, bands_file: bandsSileSiesta): """Gets the bands data from a SIESTA .bands file""" - + bands_data = bands_file.read_data(as_dataarray=True) - bands_data.k.attrs['axis'] = { - 'tickvals': bands_data.attrs.pop('ticks'), - 'ticktext': bands_data.attrs.pop('ticklabels') + bands_data.k.attrs["axis"] = { + "tickvals": bands_data.attrs.pop("ticks"), + "ticktext": bands_data.attrs.pop("ticklabels"), } return cls.new(bands_data) - + @new.register @classmethod - def from_hamiltonian(cls, - bz: sisl.BrillouinZone, - H: Union[sisl.Hamiltonian, None] = None, - extra_vars: Sequence[Union[Dict, str]] = () + def from_hamiltonian( + cls, + bz: sisl.BrillouinZone, + H: Union[sisl.Hamiltonian, None] = None, + extra_vars: Sequence[Union[Dict, str]] = (), ): """Uses a sisl's `BrillouinZone` object to calculate the bands.""" if bz is None: @@ -314,10 +362,9 @@ def from_hamiltonian(cls, # Get a dataset with all values for all spin indices spin_datasets = [] - coords = [var['coords'] for var in all_vars] - name = [var['name'] for var in all_vars] - for spin_index in coords_values['spin']: - + coords = [var["coords"] for var in all_vars] + name = [var["name"] for var in all_vars] + for spin_index in coords_values["spin"]: # Non collinear routines don't accept the keyword argument "spin" spin_kwarg = {"spin": spin_index} if not spin.is_diagonal: @@ -327,7 +374,8 @@ def from_hamiltonian(cls, spin_bands = parallel.dataarray.eigenstate( wrap=partial(bands_wrapper, spin_index=spin_index), **spin_kwarg, - coords=coords, name=name, + coords=coords, + name=name, ) spin_datasets.append(spin_bands) @@ -338,10 +386,11 @@ def from_hamiltonian(cls, # If the band structure contains discontinuities, we will copy the dataset # adding the discontinuities. if isinstance(bz, sisl.BandStructure) and len(bz._jump_idx) > 0: - old_coords = bands_data.coords coords = { - name: bz.insert_jump(old_coords[name]) if name == "k" else old_coords[name].values + name: bz.insert_jump(old_coords[name]) + if name == "k" + else old_coords[name].values for name in old_coords } @@ -354,15 +403,15 @@ def _add_jump(array): bands_data = xr.Dataset( {name: _add_jump(bands_data[name]) for name in bands_data}, - coords=coords + coords=coords, ) - + # Add the spin class of the data - bands_data.attrs['spin'] = spin + bands_data.attrs["spin"] = spin # Inform of where to place the ticks bands_data.k.attrs["axis"] = { - "tickvals": ticks[0], + "tickvals": ticks[0], "ticktext": ticks[1], } @@ -370,7 +419,9 @@ def _add_jump(array): @new.register @classmethod - def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=False): + def from_wfsx( + cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=False + ): """Plots bands from the eigenvalues contained in a WFSX file. It also needs to get a geometry. @@ -378,7 +429,9 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa if need_H: H = HamiltonianDataSource(H=fdf) if H is None: - raise ValueError("Hamiltonian was not setup, and it is needed for the calculations") + raise ValueError( + "Hamiltonian was not setup, and it is needed for the calculations" + ) parent = H geometry = parent.geometry else: @@ -389,14 +442,18 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa parent = geometry # Get the wfsx file - wfsx_sile = FileDataSIESTA(fdf=fdf, path=wfsx_file, cls=sisl.io.wfsxSileSiesta, parent=parent) + wfsx_sile = FileDataSIESTA( + fdf=fdf, path=wfsx_file, cls=sisl.io.wfsxSileSiesta, parent=parent + ) # Now read all the information of the k points from the WFSX file k, weights, nwfs = wfsx_sile.read_info() # Get the number of wavefunctions in the file while performing a quick check nwf = np.unique(nwfs) if len(nwf) > 1: - raise ValueError(f"File {wfsx_sile.file} contains different number of wavefunctions in some k points") + raise ValueError( + f"File {wfsx_sile.file} contains different number of wavefunctions in some k points" + ) nwf = nwf[0] # From the k values read in the file, build a brillouin zone object. # We will use it just to get the linear k values for plotting. @@ -407,10 +464,14 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa nspin, nou, nk, _ = wfsx_sile.read_sizes() # Find out the spin class of the calculation. - spin = Spin({ - 1: Spin.UNPOLARIZED, 2: Spin.POLARIZED, - 4: Spin.NONCOLINEAR, 8: Spin.SPINORBIT - }[nspin]) + spin = Spin( + { + 1: Spin.UNPOLARIZED, + 2: Spin.POLARIZED, + 4: Spin.NONCOLINEAR, + 8: Spin.SPINORBIT, + }[nspin] + ) # Now find out how many spin channels we need. Note that if there is only # one spin channel there will be no "spin" dimension on the final dataset. nspin = 2 if spin.is_polarized else 1 @@ -423,7 +484,9 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa H = sisl.get_sile(fdf).read_hamiltonian() if H is not None: # We could read a hamiltonian, set it as the parent of the wfsx sile - wfsx_sile = FileDataSIESTA(path=wfsx_sile.file, kwargs=dict(parent=parent)) + wfsx_sile = FileDataSIESTA( + path=wfsx_sile.file, kwargs=dict(parent=parent) + ) spin_moments = True except: pass @@ -431,13 +494,15 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa # Get the wrapper function that we should call on each eigenstate. # This also returns the coordinates and names to build the final dataset. bands_wrapper, all_vars, coords_values = _get_eigenstate_wrapper( - sisl.physics.linspace_bz(bz), extra_vars=extra_vars, - spin_moments=spin_moments, spin=spin + sisl.physics.linspace_bz(bz), + extra_vars=extra_vars, + spin_moments=spin_moments, + spin=spin, ) # Make sure all coordinates have values so that we can assume the shape # of arrays below. - coords_values['band'] = np.arange(0, nwf) - coords_values['orb'] = np.arange(0, nou) + coords_values["band"] = np.arange(0, nwf) + coords_values["orb"] = np.arange(0, nou) # Initialize all the arrays. For each quantity we will initialize # an array of the needed shape. @@ -445,12 +510,12 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa for var in all_vars: # These are all the extra dimensions of the quantity. Note that a # quantity does not need to have extra dimensions. - extra_shape = [len(coords_values[coord]) for coord in var['coords']] + extra_shape = [len(coords_values[coord]) for coord in var["coords"]] # First two dimensions will always be the spin channel and the k index. # Then add potential extra dimensions. shape = (nspin, len(bz), *extra_shape) # Initialize the array. - arrays[var['name']] = np.empty(shape, dtype=var.get('dtype', np.float64)) + arrays[var["name"]] = np.empty(shape, dtype=var.get("dtype", np.float64)) # Loop through eigenstates in the WFSX file and add their contribution to the bands ik = -1 @@ -458,24 +523,24 @@ def from_wfsx(cls, wfsx_file: wfsxSileSiesta, fdf: str, extra_vars=(), need_H=Fa i_spin = eigenstate.info.get("spin", 0) # Every time we encounter spin 0, we are in a new k point. if i_spin == 0: - ik +=1 + ik += 1 if ik == 0: # If this is the first eigenstate we read, get the wavefunction # indices. We will assume that ALL EIGENSTATES have the same indices. # Note that we already checked previously that they all have the same # number of wfs, so this is a fair assumption. - coords_values['band'] = eigenstate.info['index'] + coords_values["band"] = eigenstate.info["index"] # Get all the values for this eigenstate. returns = bands_wrapper(eigenstate, spin_index=i_spin) # And store them in the respective arrays. for var, vals in zip(all_vars, returns): - arrays[var['name']][i_spin, ik] = vals + arrays[var["name"]][i_spin, ik] = vals # Now that we have all the values, just build the dataset. bands_data = xr.Dataset( data_vars={ - var['name']: (("spin", "k", *var['coords']), arrays[var['name']]) + var["name"]: (("spin", "k", *var["coords"]), arrays[var["name"]]) for var in all_vars } ).assign_coords(coords_values) @@ -508,7 +573,7 @@ def from_aiida(cls, aiida_bands: Aiida_node): bands, coords={ "spin": np.arange(0, bands.shape[0]), - "k": ('k', plot_data["x"], {"axis": tick_info}), + "k": ("k", plot_data["x"], {"axis": tick_info}), "band": np.arange(0, bands.shape[2]), }, dims=("spin", "k", "band"), @@ -516,7 +581,10 @@ def from_aiida(cls, aiida_bands: Aiida_node): return cls.new(data) -def _get_eigenstate_wrapper(k_vals, spin, extra_vars: Sequence[Union[Dict, str]] = (), spin_moments: bool = True): + +def _get_eigenstate_wrapper( + k_vals, spin, extra_vars: Sequence[Union[Dict, str]] = (), spin_moments: bool = True +): """Helper function to build the function to call on each eigenstate. Parameters @@ -528,7 +596,7 @@ def _get_eigenstate_wrapper(k_vals, spin, extra_vars: Sequence[Union[Dict, str]] This argument determines the extra quantities that should be included in the final dataset of the bands. Energy and spin moments (if available) are already included, so no need to pass them here. - Each item of the array defines a new quantity and should contain a dictionary + Each item of the array defines a new quantity and should contain a dictionary with the following keys: - 'name', str: The name of the quantity. - 'getter', callable: A function that gets 3 arguments: eigenstate, plot and @@ -537,7 +605,7 @@ def _get_eigenstate_wrapper(k_vals, spin, extra_vars: Sequence[Union[Dict, str]] for each (k-point, spin) combination. - 'coords', tuple of str: The names of the dimensions of the returned array. The number of coordinates should match the number of dimensions. - of + of - 'coords_values', dict: If this variable introduces a new coordinate, you should pass the values for that coordinate here. If the coordinates were already defined by another variable, they will already have values. If you are unsure that the @@ -568,10 +636,14 @@ def _get_eigenstate_wrapper(k_vals, spin, extra_vars: Sequence[Union[Dict, str]] spin_indices = [0, 1] # Add a variable to get the eigenvalues. - all_vars = ({ - "coords": ("band",), "coords_values": {"spin": spin_indices, "k": k_vals}, - "name": "E", "getter": lambda eigenstate, spin, spin_index: eigenstate.eig}, - *extra_vars + all_vars = ( + { + "coords": ("band",), + "coords_values": {"spin": spin_indices, "k": k_vals}, + "name": "E", + "getter": lambda eigenstate, spin, spin_index: eigenstate.eig, + }, + *extra_vars, ) # Convert known variable keys to actual variables. @@ -591,8 +663,8 @@ def bands_wrapper(eigenstate, spin_index): return bands_wrapper, all_vars, coords_values -def _norm2_from_eigenstate(eigenstate, spin, spin_index): +def _norm2_from_eigenstate(eigenstate, spin, spin_index): norm2 = eigenstate.norm2(sum=False) if not spin.is_diagonal: @@ -603,17 +675,21 @@ def _norm2_from_eigenstate(eigenstate, spin, spin_index): return norm2.real + def _spin_moment_getter(eigenstate, spin, spin_index): return eigenstate.spin_moment().real + _KNOWN_EIGENSTATE_VARS = { "norm2": { - "coords": ("band", "orb"), - "name": "norm2", - "getter": _norm2_from_eigenstate + "coords": ("band", "orb"), + "name": "norm2", + "getter": _norm2_from_eigenstate, }, "spin_moment": { - "coords": ("axis", "band"), "coords_values": dict(axis=["x", "y", "z"]), - "name": "spin_moments", "getter": _spin_moment_getter - } + "coords": ("axis", "band"), + "coords_values": dict(axis=["x", "y", "z"]), + "name": "spin_moments", + "getter": _spin_moment_getter, + }, } diff --git a/src/sisl/viz/data/data.py b/src/sisl/viz/data/data.py index ddd9d932c2..8522dc696c 100644 --- a/src/sisl/viz/data/data.py +++ b/src/sisl/viz/data/data.py @@ -13,24 +13,24 @@ def __init__(self, data): self._data = data def sanity_check(self): - def is_valid(data, expected_type) -> bool: if expected_type is Any: return True - + return isinstance(data, expected_type) - expected_type = get_type_hints(self.__class__)['_data'] + expected_type = get_type_hints(self.__class__)["_data"] if get_origin(expected_type) is Union: - valid = False for valid_type in get_args(expected_type): valid = valid | is_valid(self._data, valid_type) - + else: valid = is_valid(self._data, expected_type) - assert valid, f"Data must be of type {expected_type} but is {type(self._data).__name__}" + assert ( + valid + ), f"Data must be of type {expected_type} but is {type(self._data).__name__}" def __getattr__(self, key): return getattr(self._data, key) diff --git a/src/sisl/viz/data/eigenstate.py b/src/sisl/viz/data/eigenstate.py index d53e2a4a4d..9340682cff 100644 --- a/src/sisl/viz/data/eigenstate.py +++ b/src/sisl/viz/data/eigenstate.py @@ -11,17 +11,17 @@ class EigenstateData(Data): """Wavefunction data class""" - + @singledispatchmethod @classmethod def new(cls, data): return cls(data) - + @new.register @classmethod def from_eigenstate(cls, eigenstate: sisl.EigenstateElectron): return cls(eigenstate) - + @new.register @classmethod def from_path(cls, path: Path, *args, **kwargs): @@ -33,12 +33,15 @@ def from_path(cls, path: Path, *args, **kwargs): def from_string(cls, string: str, *args, **kwargs): """Assumes the string is a path to a file""" return cls.new(Path(string), *args, **kwargs) - + @new.register @classmethod - def from_fdf(cls, - fdf: fdfSileSiesta, source: Literal["wfsx", "hamiltonian"] = "wfsx", - k: Tuple[float, float, float] = (0, 0, 0), spin: int = 0, + def from_fdf( + cls, + fdf: fdfSileSiesta, + source: Literal["wfsx", "hamiltonian"] = "wfsx", + k: Tuple[float, float, float] = (0, 0, 0), + spin: int = 0, ): if source == "wfsx": sile = FileDataSIESTA(fdf=fdf, cls=wfsxSileSiesta) @@ -55,7 +58,13 @@ def from_fdf(cls, @new.register @classmethod - def from_siesta_wfsx(cls, wfsx_file: wfsxSileSiesta, geometry: sisl.Geometry, k: Tuple[float, float, float] = (0, 0, 0), spin: int = 0): + def from_siesta_wfsx( + cls, + wfsx_file: wfsxSileSiesta, + geometry: sisl.Geometry, + k: Tuple[float, float, float] = (0, 0, 0), + spin: int = 0, + ): """Reads the wavefunction coefficients from a SIESTA WFSX file""" # Get the WFSX file. If not provided, it is inferred from the fdf. if not wfsx_file.file.is_file(): @@ -71,14 +80,19 @@ def from_siesta_wfsx(cls, wfsx_file: wfsxSileSiesta, geometry: sisl.Geometry, k: if eigenstate is None: # We have not found it. raise ValueError(f"A state with k={k} was not found in file {wfsx.file}.") - + return cls.new(eigenstate) @new.register @classmethod - def from_hamiltonian(cls, H: sisl.Hamiltonian, k: Tuple[float, float, float] = (0, 0, 0), spin: int = 0): + def from_hamiltonian( + cls, + H: sisl.Hamiltonian, + k: Tuple[float, float, float] = (0, 0, 0), + spin: int = 0, + ): """Calculates the eigenstates from a Hamiltonian and then generates the wavefunctions.""" return cls.new(H.eigenstate(k, spin=spin)) - + def __getitem__(self, key): - return self._data[key] \ No newline at end of file + return self._data[key] diff --git a/src/sisl/viz/data/pdos.py b/src/sisl/viz/data/pdos.py index 7971cf6902..2aed845093 100644 --- a/src/sisl/viz/data/pdos.py +++ b/src/sisl/viz/data/pdos.py @@ -20,21 +20,26 @@ try: import pathos + _do_parallel_calc = True except: _do_parallel_calc = False + class PDOSData(OrbitalData): """Holds PDOS Data in a custom xarray DataArray. - + The point of this class is to normalize the data coming from different sources - so that functions can use it without worrying where the data came from. + so that functions can use it without worrying where the data came from. """ - - def sanity_check(self, - na: Optional[int] = None, no: Optional[int] = None, n_spin: Optional[int] = None, + + def sanity_check( + self, + na: Optional[int] = None, + no: Optional[int] = None, + n_spin: Optional[int] = None, atom_tags: Optional[Sequence[str]] = None, - dos_checksum: Optional[float] = None + dos_checksum: Optional[float] = None, ): """Check that the dataarray satisfies the requirements to be treated as PDOSData.""" super().sanity_check() @@ -48,10 +53,15 @@ def sanity_check(self, if no is not None: assert geometry.no == no if atom_tags is not None: - assert len(set(atom_tags) - set([atom.tag for atom in geometry.atoms.atom])) == 0 - + assert ( + len(set(atom_tags) - set([atom.tag for atom in geometry.atoms.atom])) + == 0 + ) + for k in ("spin", "orb", "E"): - assert k in array.dims, f"'{k}' dimension missing, existing dimensions: {array.dims}" + assert ( + k in array.dims + ), f"'{k}' dimension missing, existing dimensions: {array.dims}" # Check if we have the correct number of spin channels if n_spin is not None: @@ -62,10 +72,17 @@ def sanity_check(self, # Check if the checksum of the DOS is correct if dos_checksum is not None: this_dos_checksum = float(array.sum()) - assert np.allclose(this_dos_checksum, dos_checksum), f"Checksum of the DOS is incorrect. Expected {dos_checksum} but got {this_dos_checksum}" + assert np.allclose( + this_dos_checksum, dos_checksum + ), f"Checksum of the DOS is incorrect. Expected {dos_checksum} but got {this_dos_checksum}" @classmethod - def toy_example(cls, geometry: Optional[Geometry] = None, spin: Union[str, int, Spin] = "", nE: int = 100): + def toy_example( + cls, + geometry: Optional[Geometry] = None, + spin: Union[str, int, Spin] = "", + nE: int = 100, + ): """Creates a toy example of a bands data array""" if geometry is None: @@ -73,9 +90,9 @@ def toy_example(cls, geometry: Optional[Geometry] = None, spin: Union[str, int, sisl.AtomicOrbital("2s"), sisl.AtomicOrbital("2px"), sisl.AtomicOrbital("2py"), - sisl.AtomicOrbital("2pz") + sisl.AtomicOrbital("2pz"), ] - + geometry = sisl.geom.graphene(atoms=sisl.Atom(Z=6, orbitals=orbitals)) PDOS = np.random.rand(geometry.no, nE) @@ -96,9 +113,14 @@ def new(cls, data: DataArray) -> "PDOSData": @new.register @classmethod - def from_numpy(cls, - PDOS: np.ndarray, geometry: Geometry, E: Sequence[float], E_units: str = 'eV', - spin: Optional[Union[sisl.Spin, str, int]] = None, extra_attrs: dict = {} + def from_numpy( + cls, + PDOS: np.ndarray, + geometry: Geometry, + E: Sequence[float], + E_units: str = "eV", + spin: Optional[Union[sisl.Spin, str, int]] = None, + extra_attrs: dict = {}, ): """ @@ -119,7 +141,7 @@ def from_numpy(cls, E_units: str, optional The units of the energy. Defaults to 'eV'. extra_attrs: dict - A dictionary of extra attributes to be added to the DataArray. One of the attributes that + A dictionary of extra attributes to be added to the DataArray. One of the attributes that """ # Understand what the spin class is for this data. data_spin = sisl.Spin.UNPOLARIZED @@ -127,7 +149,7 @@ def from_numpy(cls, data_spin = { 1: sisl.Spin.UNPOLARIZED, 2: sisl.Spin.POLARIZED, - 4: sisl.Spin.NONCOLINEAR + 4: sisl.Spin.NONCOLINEAR, }[PDOS.shape[0]] data_spin = sisl.Spin(data_spin) @@ -150,29 +172,35 @@ def from_numpy(cls, orb_dim = PDOS.ndim - 2 if geometry is not None: if geometry.no != PDOS.shape[orb_dim]: - raise ValueError(f"The geometry provided contains {geometry.no} orbitals, while we have PDOS information of {PDOS.shape[orb_dim]}.") - + raise ValueError( + f"The geometry provided contains {geometry.no} orbitals, while we have PDOS information of {PDOS.shape[orb_dim]}." + ) + # Build the standardized dataarray, with everything needed to understand it. E_units = extra_attrs.pop("E_units", "eV") if spin.is_polarized: - spin_coords = ['total', 'z'] + spin_coords = ["total", "z"] elif not spin.is_diagonal: spin_coords = get_spin_options(spin) else: spin_coords = ["total"] - coords = [("spin", spin_coords), ("orb", range(PDOS.shape[orb_dim])), ("E", E, {"units": E_units})] - - attrs = {"spin": spin, "geometry": geometry, "units": f"1/{E_units}", **extra_attrs} - - return cls.new(DataArray( - PDOS, - coords=coords, - name="PDOS", - attrs=attrs - )) - + coords = [ + ("spin", spin_coords), + ("orb", range(PDOS.shape[orb_dim])), + ("E", E, {"units": E_units}), + ] + + attrs = { + "spin": spin, + "geometry": geometry, + "units": f"1/{E_units}", + **extra_attrs, + } + + return cls.new(DataArray(PDOS, coords=coords, name="PDOS", attrs=attrs)) + @new.register @classmethod def from_path(cls, path: Path, *args, **kwargs): @@ -187,14 +215,16 @@ def from_string(cls, string: str, *args, **kwargs): @new.register @classmethod - def from_fdf(cls, - fdf: fdfSileSiesta, source: Literal["pdos", "tbtnc", "wfsx", "hamiltonian"] = "pdos", - **kwargs + def from_fdf( + cls, + fdf: fdfSileSiesta, + source: Literal["pdos", "tbtnc", "wfsx", "hamiltonian"] = "pdos", + **kwargs, ): """Gets the PDOS from the fdf file. It uses the fdf file as the pivoting point to find the rest of files needed. - + Parameters ---------- fdf: fdfSileSiesta @@ -203,7 +233,7 @@ def from_fdf(cls, The source to read the PDOS data from. **kwargs Extra arguments to be passed to the PDOSData constructor, which depends - on the source requested. + on the source requested. Except for the hamiltonian source, no extra arguments are needed (and they won't be used). See PDOSData.from_hamiltonian for the extra arguments accepted @@ -244,10 +274,12 @@ def from_siesta_pdos(cls, pdos_file: pdosSileSiesta): geometry, E, PDOS = pdos_file.read_data() return cls.new(PDOS, geometry, E) - + @new.register @classmethod - def from_tbtrans(cls, tbt_nc: tbtncSileTBtrans, geometry: Union[Geometry, None] = None): + def from_tbtrans( + cls, tbt_nc: tbtncSileTBtrans, geometry: Union[Geometry, None] = None + ): """Reads the PDOS from a *.TBT.nc file coming from a TBtrans run.""" PDOS = tbt_nc.DOS(sum=False).T E = tbt_nc.E @@ -260,11 +292,19 @@ def from_tbtrans(cls, tbt_nc: tbtncSileTBtrans, geometry: Union[Geometry, None] geometry = tbt_nc.read_geometry(**read_geometry_kwargs).sub(tbt_nc.a_dev) return cls.new(PDOS, geometry, E) - + @new.register @classmethod - def from_hamiltonian(cls, H: Hamiltonian, kgrid=None, kgrid_displ=(0, 0, 0), Erange=(-2, 2), - E0=0, nE=100, distribution=get_distribution("gaussian")): + def from_hamiltonian( + cls, + H: Hamiltonian, + kgrid=None, + kgrid_displ=(0, 0, 0), + Erange=(-2, 2), + E0=0, + nE=100, + distribution=get_distribution("gaussian"), + ): """Calculates the PDOS from a sisl Hamiltonian.""" # Get the kgrid or generate a default grid by checking the interaction between cells @@ -275,7 +315,9 @@ def from_hamiltonian(cls, H: Hamiltonian, kgrid=None, kgrid_displ=(0, 0, 0), Era Erange = Erange if Erange is None: - raise ValueError('You need to provide an energy range to calculate the PDOS from the Hamiltonian') + raise ValueError( + "You need to provide an energy range to calculate the PDOS from the Hamiltonian" + ) E = np.linspace(Erange[0], Erange[-1], nE) + E0 @@ -291,8 +333,7 @@ def from_hamiltonian(cls, H: Hamiltonian, kgrid=None, kgrid_displ=(0, 0, 0), Era for spin in spin_indices: with bz.apply(pool=_do_parallel_calc) as parallel: spin_PDOS = parallel.average.eigenstate( - spin=spin, - wrap=lambda eig: eig.PDOS(E, distribution=distribution) + spin=spin, wrap=lambda eig: eig.PDOS(E, distribution=distribution) ) PDOS.append(spin_PDOS) @@ -308,30 +349,34 @@ def from_hamiltonian(cls, H: Hamiltonian, kgrid=None, kgrid_displ=(0, 0, 0), Era PDOS = np.array(PDOS) - return cls.new(PDOS, H.geometry, E, spin=H.spin, extra_attrs={'bz': bz}) - + return cls.new(PDOS, H.geometry, E, spin=H.spin, extra_attrs={"bz": bz}) + @new.register @classmethod - def from_wfsx(cls, + def from_wfsx( + cls, wfsx_file: wfsxSileSiesta, - H: Hamiltonian, geometry: Union[Geometry, None] = None, - Erange=(-2, 2), nE: int = 100, E0: float = 0, distribution=get_distribution('gaussian') + H: Hamiltonian, + geometry: Union[Geometry, None] = None, + Erange=(-2, 2), + nE: int = 100, + E0: float = 0, + distribution=get_distribution("gaussian"), ): """Generates the PDOS values from a file containing eigenstates.""" if geometry is None: geometry = getattr(H, "geometry", None) # Get the wfsx file - wfsx_sile = FileDataSIESTA( - path=wfsx_file, cls=sisl.io.wfsxSileSiesta, parent=H - ) + wfsx_sile = FileDataSIESTA(path=wfsx_file, cls=sisl.io.wfsxSileSiesta, parent=H) # Read the sizes of the file, which contain the number of spin channels # and the number of orbitals and the number of k points. sizes = wfsx_sile.read_sizes() # Check that spin sizes of hamiltonian and wfsx file match - assert H.spin.size == sizes.nspin, \ - f"Hamiltonian has spin size {H.spin.size} while file has spin size {sizes.nspin}" + assert ( + H.spin.size == sizes.nspin + ), f"Hamiltonian has spin size {H.spin.size} while file has spin size {sizes.nspin}" # Get the size of the spin channel. The size returned might be 8 if it is a spin-orbit # calculation, but we need only 4 spin channels (total, x, y and z), same as with non-colinear nspin = min(4, sizes.nspin) @@ -351,6 +396,8 @@ def from_wfsx(cls, if nspin == 4: spin = slice(None) - PDOS[spin] += eigenstate.PDOS(E, distribution=distribution) * eigenstate.info.get("weight", 1) + PDOS[spin] += eigenstate.PDOS( + E, distribution=distribution + ) * eigenstate.info.get("weight", 1) return cls.new(PDOS, geometry, E, spin=H.spin) diff --git a/src/sisl/viz/data/sisl_objs.py b/src/sisl/viz/data/sisl_objs.py index dcc4b78126..d8fdc20075 100644 --- a/src/sisl/viz/data/sisl_objs.py +++ b/src/sisl/viz/data/sisl_objs.py @@ -7,22 +7,29 @@ class SislObjData(Data): """Base class for sisl objects""" + def __instancecheck__(self, instance: Any) -> bool: - expected_type = get_type_hints(self.__class__)['_data'] + expected_type = get_type_hints(self.__class__)["_data"] return isinstance(instance, expected_type) - + def __subclasscheck__(self, subclass: Any) -> bool: - expected_type = get_type_hints(self.__class__)['_data'] + expected_type = get_type_hints(self.__class__)["_data"] return issubclass(subclass, expected_type) + class GeometryData(Data): """Geometry data class""" + _data: Geometry + class GridData(Data): """Grid data class""" + _data: Grid + class HamiltonianData(Data): """Hamiltonian data class""" - _data: Hamiltonian \ No newline at end of file + + _data: Hamiltonian diff --git a/src/sisl/viz/data/tests/conftest.py b/src/sisl/viz/data/tests/conftest.py index a752c0a50b..f444814153 100644 --- a/src/sisl/viz/data/tests/conftest.py +++ b/src/sisl/viz/data/tests/conftest.py @@ -5,8 +5,7 @@ @pytest.fixture(scope="session") def siesta_test_files(sisl_files): - def _siesta_test_files(path): - return sisl_files(osp.join('sisl', 'io', 'siesta', path)) + return sisl_files(osp.join("sisl", "io", "siesta", path)) - return _siesta_test_files \ No newline at end of file + return _siesta_test_files diff --git a/src/sisl/viz/data/tests/test_bands.py b/src/sisl/viz/data/tests/test_bands.py index 1802a5ad58..201abe83ac 100644 --- a/src/sisl/viz/data/tests/test_bands.py +++ b/src/sisl/viz/data/tests/test_bands.py @@ -6,7 +6,9 @@ pytestmark = [pytest.mark.viz, pytest.mark.data] -@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) +@pytest.mark.parametrize( + "spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"] +) def test_bands_from_sisl_H(spin): gr = sisl.geom.graphene() H = sisl.Hamiltonian(gr) @@ -16,18 +18,26 @@ def test_bands_from_sisl_H(spin): "unpolarized": (1, H), "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), "noncolinear": (4, H.transform(spin=sisl.Spin.NONCOLINEAR)), - "spinorbit": (4, H.transform(spin=sisl.Spin.SPINORBIT)) + "spinorbit": (4, H.transform(spin=sisl.Spin.SPINORBIT)), }[spin] - bz = sisl.BandStructure(H, [[0, 0, 0], [2/3, 1/3, 0], [1/2, 0, 0]], 6, ["Gamma", "M", "K"]) + bz = sisl.BandStructure( + H, [[0, 0, 0], [2 / 3, 1 / 3, 0], [1 / 2, 0, 0]], 6, ["Gamma", "M", "K"] + ) data = BandsData.new(bz) - data.sanity_check(n_spin=n_spin, nk=6, nbands=2, klabels=["Gamma", "M", "K"], kvals=[0., 1.70309799, 2.55464699]) + data.sanity_check( + n_spin=n_spin, + nk=6, + nbands=2, + klabels=["Gamma", "M", "K"], + kvals=[0.0, 1.70309799, 2.55464699], + ) + @pytest.mark.parametrize("spin", ["unpolarized"]) def test_bands_from_siesta_bands(spin, siesta_test_files): - n_spin, filename = { "unpolarized": (1, "SrTiO3.bands"), }[spin] @@ -36,11 +46,17 @@ def test_bands_from_siesta_bands(spin, siesta_test_files): data = BandsData.new(file) - data.sanity_check(n_spin=n_spin, nk=150, nbands=72, klabels=('Gamma', 'X', 'M', 'Gamma', 'R', 'X'), kvals=[0.0, 0.429132, 0.858265, 1.465149, 2.208428, 2.815313]) + data.sanity_check( + n_spin=n_spin, + nk=150, + nbands=72, + klabels=("Gamma", "X", "M", "Gamma", "R", "X"), + kvals=[0.0, 0.429132, 0.858265, 1.465149, 2.208428, 2.815313], + ) + @pytest.mark.parametrize("spin", ["noncolinear"]) def test_bands_from_siesta_wfsx(spin, siesta_test_files): - n_spin, filename = { "noncolinear": (4, "bi2se3_3ql.bands.WFSX"), }[spin] @@ -52,22 +68,21 @@ def test_bands_from_siesta_wfsx(spin, siesta_test_files): data.sanity_check(n_spin=n_spin, nk=16, nbands=4) -@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) -def test_toy_example(spin): +@pytest.mark.parametrize( + "spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"] +) +def test_toy_example(spin): nk = 15 n_states = 28 data = BandsData.toy_example(spin=spin, nk=nk, n_states=n_states) - n_spin = { - "unpolarized": 1, - "polarized": 2, - "noncolinear": 4, - "spinorbit": 4 - }[spin] + n_spin = {"unpolarized": 1, "polarized": 2, "noncolinear": 4, "spinorbit": 4}[spin] - data.sanity_check(n_spin=n_spin, nk=nk, nbands=n_states, klabels=["Gamma", "X"], kvals=[0, 1]) + data.sanity_check( + n_spin=n_spin, nk=nk, nbands=n_states, klabels=["Gamma", "X"], kvals=[0, 1] + ) if n_spin == 4: assert "spin_moments" in data.data_vars diff --git a/src/sisl/viz/data/tests/test_pdos.py b/src/sisl/viz/data/tests/test_pdos.py index b05f28e18d..43621e9629 100644 --- a/src/sisl/viz/data/tests/test_pdos.py +++ b/src/sisl/viz/data/tests/test_pdos.py @@ -6,7 +6,9 @@ pytestmark = [pytest.mark.viz, pytest.mark.data] -@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) +@pytest.mark.parametrize( + "spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"] +) def test_pdos_from_sisl_H(spin): gr = sisl.geom.graphene() H = sisl.Hamiltonian(gr) @@ -16,7 +18,7 @@ def test_pdos_from_sisl_H(spin): "unpolarized": (1, H), "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), "noncolinear": (4, H.transform(spin=sisl.Spin.NONCOLINEAR)), - "spinorbit": (4, H.transform(spin=sisl.Spin.SPINORBIT)) + "spinorbit": (4, H.transform(spin=sisl.Spin.SPINORBIT)), }[spin] data = PDOSData.new(H, Erange=(-5, 5)) @@ -25,15 +27,17 @@ def test_pdos_from_sisl_H(spin): if n_spin > 1: checksum = checksum * 2 - data.sanity_check(na=2, no=2, n_spin=n_spin, atom_tags=('C',), dos_checksum=checksum) + data.sanity_check( + na=2, no=2, n_spin=n_spin, atom_tags=("C",), dos_checksum=checksum + ) + @pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear"]) def test_pdos_from_siesta_PDOS(spin, siesta_test_files): - n_spin, filename = { "unpolarized": (1, "SrTiO3.PDOS"), "polarized": (2, "SrTiO3_polarized.PDOS"), - "noncolinear": (4, "SrTiO3_noncollinear.PDOS") + "noncolinear": (4, "SrTiO3_noncollinear.PDOS"), }[spin] file = siesta_test_files(filename) @@ -44,11 +48,13 @@ def test_pdos_from_siesta_PDOS(spin, siesta_test_files): if n_spin > 1: checksum = checksum * 2 - data.sanity_check(na=5, no=72, n_spin=n_spin, atom_tags=('Sr', 'Ti', 'O'), dos_checksum=checksum) + data.sanity_check( + na=5, no=72, n_spin=n_spin, atom_tags=("Sr", "Ti", "O"), dos_checksum=checksum + ) + @pytest.mark.parametrize("spin", ["noncolinear"]) def test_pdos_from_siesta_wfsx(spin, siesta_test_files): - n_spin, filename = { "noncolinear": (4, "bi2se3_3ql.bands.WFSX"), }[spin] @@ -69,18 +75,19 @@ def test_pdos_from_siesta_wfsx(spin, siesta_test_files): if n_spin > 1: checksum = checksum * 2 - data.sanity_check(na=15, no=195, n_spin=n_spin, atom_tags=('Bi', 'Se'), dos_checksum=checksum) + data.sanity_check( + na=15, no=195, n_spin=n_spin, atom_tags=("Bi", "Se"), dos_checksum=checksum + ) -@pytest.mark.parametrize("spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"]) -def test_toy_example(spin): +@pytest.mark.parametrize( + "spin", ["unpolarized", "polarized", "noncolinear", "spinorbit"] +) +def test_toy_example(spin): data = PDOSData.toy_example(spin=spin) - n_spin = { - "unpolarized": 1, - "polarized": 2, - "noncolinear": 4, - "spinorbit": 4 - }[spin] + n_spin = {"unpolarized": 1, "polarized": 2, "noncolinear": 4, "spinorbit": 4}[spin] - data.sanity_check(n_spin=n_spin, na=data.geometry.na, no=data.geometry.no, atom_tags=["C"]) + data.sanity_check( + n_spin=n_spin, na=data.geometry.na, no=data.geometry.no, atom_tags=["C"] + ) diff --git a/src/sisl/viz/data/xarray.py b/src/sisl/viz/data/xarray.py index 6a8747d025..78ad61f6e5 100644 --- a/src/sisl/viz/data/xarray.py +++ b/src/sisl/viz/data/xarray.py @@ -6,7 +6,6 @@ class XarrayData(Data): - _data: Union[DataArray, Dataset] def __init__(self, data: Union[DataArray, Dataset]): @@ -17,7 +16,7 @@ def __getattr__(self, key): if hasattr(sisl_accessor, key): return getattr(sisl_accessor, key) - + return getattr(self._data, key) def __dir__(self): @@ -25,4 +24,4 @@ def __dir__(self): class OrbitalData(XarrayData): - pass \ No newline at end of file + pass diff --git a/src/sisl/viz/data_sources/atom_data.py b/src/sisl/viz/data_sources/atom_data.py index 9163db0776..76b7406157 100644 --- a/src/sisl/viz/data_sources/atom_data.py +++ b/src/sisl/viz/data_sources/atom_data.py @@ -9,52 +9,63 @@ class AtomData(DataSource): def function(self, geometry, atoms=None): raise NotImplementedError("") + @AtomData.from_func def AtomCoords(geometry, atoms=None): return geometry.xyz[atoms] + @AtomData.from_func def AtomX(geometry, atoms=None): return geometry.xyz[atoms, 0] + @AtomData.from_func def AtomY(geometry, atoms=None): return geometry.xyz[atoms, 1] + @AtomData.from_func def AtomZ(geometry, atoms=None): return geometry.xyz[atoms, 2] + @AtomData.from_func def AtomFCoords(geometry, atoms=None): return geometry.sub(atoms).fxyz + @AtomData.from_func def AtomFx(geometry, atoms=None): return geometry.sub(atoms).fxyz[:, 0] + @AtomData.from_func def AtomFy(geometry, atoms=None): return geometry.sub(atoms).fxyz[:, 1] + @AtomData.from_func def AtomFz(geometry, atoms=None): return geometry.sub(atoms).fxyz[:, 2] + @AtomData.from_func def AtomR(geometry, atoms=None): return geometry.sub(atoms).maxR(all=True) + @AtomData.from_func def AtomZ(geometry, atoms=None): return geometry.sub(atoms).atoms.Z + @AtomData.from_func def AtomNOrbitals(geometry, atoms=None): return geometry.sub(atoms).orbitals -class AtomDefaultColors(AtomData): +class AtomDefaultColors(AtomData): _atoms_colors = { "H": "#cccccc", "O": "red", @@ -64,25 +75,31 @@ class AtomDefaultColors(AtomData): "S": "yellow", "P": "orange", "Au": "gold", - "else": "pink" + "else": "pink", } def function(self, geometry, atoms=None): - return np.array([ - self._atoms_colors.get(atom.symbol, self._atoms_colors["else"]) - for atom in geometry.sub(atoms).atoms - ]) + return np.array( + [ + self._atoms_colors.get(atom.symbol, self._atoms_colors["else"]) + for atom in geometry.sub(atoms).atoms + ] + ) + @AtomData.from_func def AtomIsGhost(geometry, atoms=None, fill_true=True, fill_false=False): - return np.array([ - fill_true if isinstance(atom, AtomGhost) else fill_false - for atom in geometry.sub(atoms).atoms - ]) + return np.array( + [ + fill_true if isinstance(atom, AtomGhost) else fill_false + for atom in geometry.sub(atoms).atoms + ] + ) + @AtomData.from_func def AtomPeriodicTable(geometry, atoms=None, what=None, pt=PeriodicTable): if not isinstance(pt, PeriodicTable): pt = pt() function = getattr(pt, what) - return function(geometry.sub(atoms).atoms.Z) \ No newline at end of file + return function(geometry.sub(atoms).atoms.Z) diff --git a/src/sisl/viz/data_sources/bond_data.py b/src/sisl/viz/data_sources/bond_data.py index b74d17a7a5..c87ac46e91 100644 --- a/src/sisl/viz/data_sources/bond_data.py +++ b/src/sisl/viz/data_sources/bond_data.py @@ -9,15 +9,15 @@ class BondData(DataSource): - ndim: int - + @staticmethod def function(geometry, bonds): raise NotImplementedError("") pass + def bond_lengths(geometry: sisl.Geometry, bonds: np.ndarray): # Get an array with the coordinates defining the start and end of each bond. # The array will be of shape (nbonds, 2, 3) @@ -27,35 +27,51 @@ def bond_lengths(geometry: sisl.Geometry, bonds: np.ndarray): # Finally, we just ravel it to an array of shape (nbonds, ) return fnorm(np.diff(coords, axis=1), axis=-1).ravel() -def bond_strains(ref_geometry: sisl.Geometry, geometry: sisl.Geometry, bonds: np.ndarray): - assert ref_geometry.na == geometry.na, (f"Geometry provided (na={geometry.na}) does not have the" - f" same number of atoms as the reference geometry (na={ref_geometry.na})") + +def bond_strains( + ref_geometry: sisl.Geometry, geometry: sisl.Geometry, bonds: np.ndarray +): + assert ref_geometry.na == geometry.na, ( + f"Geometry provided (na={geometry.na}) does not have the" + f" same number of atoms as the reference geometry (na={ref_geometry.na})" + ) ref_bond_lengths = bond_lengths(ref_geometry, bonds) this_bond_lengths = bond_lengths(geometry, bonds) return (this_bond_lengths - ref_bond_lengths) / ref_bond_lengths -def bond_data_from_atom(atom_data: np.ndarray, geometry: sisl.Geometry, bonds: np.ndarray, fold_to_uc: bool = False): +def bond_data_from_atom( + atom_data: np.ndarray, + geometry: sisl.Geometry, + bonds: np.ndarray, + fold_to_uc: bool = False, +): if fold_to_uc: bonds = geometry.sc2uc(bonds) return atom_data[bonds[:, 0]] -def bond_data_from_matrix(matrix, geometry: sisl.Geometry, bonds: np.ndarray, fold_to_uc: bool = False): +def bond_data_from_matrix( + matrix, geometry: sisl.Geometry, bonds: np.ndarray, fold_to_uc: bool = False +): if fold_to_uc: bonds = geometry.sc2uc(bonds) return matrix[bonds[:, 0], bonds[:, 1]] -def bond_random(geometry: sisl.Geometry, bonds: np.ndarray, seed: Union[int, None] = None): + +def bond_random( + geometry: sisl.Geometry, bonds: np.ndarray, seed: Union[int, None] = None +): if seed is not None: np.random.seed(seed) return np.random.random(len(bonds)) + BondLength = BondData.from_func(bond_lengths) BondStrain = BondData.from_func(bond_strains) BondDataFromAtom = BondData.from_func(bond_data_from_atom) diff --git a/src/sisl/viz/data_sources/data_source.py b/src/sisl/viz/data_sources/data_source.py index dc49475832..433e354e03 100644 --- a/src/sisl/viz/data_sources/data_source.py +++ b/src/sisl/viz/data_sources/data_source.py @@ -3,18 +3,19 @@ class DataSource(Node): """Generic class for data sources. - + Data sources are a way of specifying and manipulating data without providing it explicitly. Data sources can be passed to the settings of the plots as if they were arrays. When the plot is being created, the data source receives the necessary inputs and is evaluated using - its ``get`` method. - + its ``get`` method. + Therefore, passing a data source is like passing a function that will receive inputs and calculate the values needed on the fly. However, it has some extra functionality. You can perform operations with a data source. These operations will be evaluated lazily, that is, when inputs are provided. That allows for very convenient manipulation of the data. Data sources are also useful for graphical interfaces, where the user is unable to explicitly - pass a function. Some of them are + pass a function. Some of them are """ + pass diff --git a/src/sisl/viz/data_sources/eigenstate_data.py b/src/sisl/viz/data_sources/eigenstate_data.py index 524290c214..38f60614c8 100644 --- a/src/sisl/viz/data_sources/eigenstate_data.py +++ b/src/sisl/viz/data_sources/eigenstate_data.py @@ -8,14 +8,17 @@ class EigenstateData(DataSource): pass -def spin_moments_from_dataset(axis: Literal['x', 'y', 'z'], data: xr.Dataset) -> xr.DataArray: + +def spin_moments_from_dataset( + axis: Literal["x", "y", "z"], data: xr.Dataset +) -> xr.DataArray: if "spin_moments" not in data: raise ValueError("The dataset does not contain spin moments") spin_moms = data.spin_moments.sel(axis=axis) - spin_moms = spin_moms.rename(f'spin_moments_{axis}') + spin_moms = spin_moms.rename(f"spin_moments_{axis}") return spin_moms -class SpinMoment(EigenstateData): - function = staticmethod(spin_moments_from_dataset) \ No newline at end of file +class SpinMoment(EigenstateData): + function = staticmethod(spin_moments_from_dataset) diff --git a/src/sisl/viz/data_sources/file/__init__.py b/src/sisl/viz/data_sources/file/__init__.py index ee984e3f57..71b37863b5 100644 --- a/src/sisl/viz/data_sources/file/__init__.py +++ b/src/sisl/viz/data_sources/file/__init__.py @@ -1,2 +1,2 @@ from .file_source import * -from .siesta import * \ No newline at end of file +from .siesta import * diff --git a/src/sisl/viz/data_sources/file/file_source.py b/src/sisl/viz/data_sources/file/file_source.py index fdb7728938..101dfd99f0 100644 --- a/src/sisl/viz/data_sources/file/file_source.py +++ b/src/sisl/viz/data_sources/file/file_source.py @@ -6,25 +6,26 @@ class FileData(DataSource): - """ Generic data source for reading data from a file. - + """Generic data source for reading data from a file. + The aim of this class is twofold: - Standarize the way data sources read files. - Provide automatic updating features when the read files are updated. """ + def __init__(self, **kwargs): super().__init__(**kwargs) self._files_to_read = [] def follow_file(self, path): self._files_to_read.append(Path(path).resolve()) - + def get_sile(self, path, **kwargs): - """ A wrapper around get_sile so that the reading of the file is registered""" + """A wrapper around get_sile so that the reading of the file is registered""" self.follow_file(path) return sisl.get_sile(path, **kwargs) def function(self, **kwargs): - if isinstance(kwargs.get('path'), sisl.io.BaseSile): - kwargs['path'] = kwargs['path'].file - return self.get_sile(**kwargs) \ No newline at end of file + if isinstance(kwargs.get("path"), sisl.io.BaseSile): + kwargs["path"] = kwargs["path"].file + return self.get_sile(**kwargs) diff --git a/src/sisl/viz/data_sources/file/siesta.py b/src/sisl/viz/data_sources/file/siesta.py index c64972258a..419f74e42d 100644 --- a/src/sisl/viz/data_sources/file/siesta.py +++ b/src/sisl/viz/data_sources/file/siesta.py @@ -6,7 +6,7 @@ def get_sile(path=None, fdf=None, cls=None, **kwargs): - """ Wrapper around FileData.get_sile that infers files from the root fdf + """Wrapper around FileData.get_sile that infers files from the root fdf Parameters ---------- @@ -28,22 +28,27 @@ def get_sile(path=None, fdf=None, cls=None, **kwargs): if cls is None: raise ValueError(f"Either a path or a class must be provided to get_sile") if fdf is None: - raise ValueError(f"We can not look for files of a sile type without a root fdf file.") - + raise ValueError( + f"We can not look for files of a sile type without a root fdf file." + ) + for rule in sisl.get_sile_rules(cls=cls): - filename = fdf.get('SystemLabel', default='siesta') + f'.{rule.suffix}' + filename = fdf.get("SystemLabel", default="siesta") + f".{rule.suffix}" try: path = fdf.dir_file(filename) return get_sile(path=path, **kwargs) except: pass else: - raise FileNotFoundError(f"Tried to find a {cls} from the root fdf ({fdf.file}), " - f"but didn't find any.") + raise FileNotFoundError( + f"Tried to find a {cls} from the root fdf ({fdf.file}), " + f"but didn't find any." + ) return sisl.get_sile(path, **kwargs) + def FileDataSIESTA(path=None, fdf=None, cls=None, **kwargs): if isinstance(path, sisl.io.BaseSile): path = path.file - return get_sile(path=path, fdf=fdf, cls=cls, **kwargs) \ No newline at end of file + return get_sile(path=path, fdf=fdf, cls=cls, **kwargs) diff --git a/src/sisl/viz/data_sources/hamiltonian_source.py b/src/sisl/viz/data_sources/hamiltonian_source.py index a72cf92e41..6b385a0e6e 100644 --- a/src/sisl/viz/data_sources/hamiltonian_source.py +++ b/src/sisl/viz/data_sources/hamiltonian_source.py @@ -7,13 +7,12 @@ class HamiltonianDataSource(DataSource): - def __init__(self, H=None, kwargs={}): super().__init__(H=H, kwargs=kwargs) def get_hamiltonian(self, H, **kwargs): - """ Setup the Hamiltonian object. - + """Setup the Hamiltonian object. + Parameters ---------- H : sisl.Hamiltonian @@ -29,6 +28,6 @@ def get_hamiltonian(self, H, **kwargs): raise ValueError("No hamiltonian found.") return H - + def function(self, H, kwargs): return self.get_hamiltonian(H=H, **kwargs) diff --git a/src/sisl/viz/data_sources/orbital_data.py b/src/sisl/viz/data_sources/orbital_data.py index e4c7b8aa52..f18200ff8e 100644 --- a/src/sisl/viz/data_sources/orbital_data.py +++ b/src/sisl/viz/data_sources/orbital_data.py @@ -8,34 +8,39 @@ from ..plotutils import random_color from .data_source import DataSource -#from ..processors.orbital import reduce_orbital_data, get_orbital_request_sanitizer +# from ..processors.orbital import reduce_orbital_data, get_orbital_request_sanitizer class OrbitalData(DataSource): pass -def style_fatbands(data, groups=[{}]): +def style_fatbands(data, groups=[{}]): # Get the function that is going to convert our request to something that can actually # select orbitals from the xarray object. _sanitize_request = get_orbital_request_sanitizer( - data, + data, gens={ "color": lambda req: req.get("color") or random_color(), - } + }, ) styled = reduce_orbital_data( - data, groups, orb_dim="orb", spin_dim="spin", sanitize_group=_sanitize_request, - group_vars=('color', 'dash'), groups_dim="group", drop_empty=True, + data, + groups, + orb_dim="orb", + spin_dim="spin", + sanitize_group=_sanitize_request, + group_vars=("color", "dash"), + groups_dim="group", + drop_empty=True, spin_reduce=np.sum, ) - return styled#.color + return styled # .color -class FatbandsData(OrbitalData): +class FatbandsData(OrbitalData): function = staticmethod(style_fatbands) pass - diff --git a/src/sisl/viz/figure/__init__.py b/src/sisl/viz/figure/__init__.py index 9bf7931f87..25a2bc7490 100644 --- a/src/sisl/viz/figure/__init__.py +++ b/src/sisl/viz/figure/__init__.py @@ -2,41 +2,51 @@ class NotAvailableFigure(Figure): - _package: str = "" def __init__(self, *args, **kwargs): - raise ModuleNotFoundError(f"{self.__class__.__name__} is not available because {self._package} is not installed.") + raise ModuleNotFoundError( + f"{self.__class__.__name__} is not available because {self._package} is not installed." + ) + try: import plotly except ModuleNotFoundError: + class PlotlyFigure(NotAvailableFigure): _package = "plotly" + else: from .plotly import PlotlyFigure try: import matplotlib except ModuleNotFoundError: + class MatplotlibFigure(NotAvailableFigure): _package = "matplotlib" + else: from .matplotlib import MatplotlibFigure try: import py3Dmol except ModuleNotFoundError: + class Py3DmolFigure(NotAvailableFigure): _package = "py3Dmol" + else: from .py3dmol import Py3DmolFigure try: import bpy except ModuleNotFoundError: + class BlenderFigure(NotAvailableFigure): _package = "blender (bpy)" + else: from .blender import BlenderFigure diff --git a/src/sisl/viz/figure/blender.py b/src/sisl/viz/figure/blender.py index 75a42005aa..7fb13fdffd 100644 --- a/src/sisl/viz/figure/blender.py +++ b/src/sisl/viz/figure/blender.py @@ -20,26 +20,36 @@ def add_line_frame(ani_objects, child_objects, frame): child_objects: CollectionObjects the objects of the Atoms collection in the child plot. frame: int - the frame number to which the keyframe values should be set. + the frame number to which the keyframe values should be set. """ # Loop through all objects in the collections for ani_obj, child_obj in zip(ani_objects, child_objects): # Each curve object has multiple splines - for ani_spline, child_spline in zip(ani_obj.data.splines, child_obj.data.splines): + for ani_spline, child_spline in zip( + ani_obj.data.splines, child_obj.data.splines + ): # And each spline has multiple points - for ani_point, child_point in zip(ani_spline.bezier_points, child_spline.bezier_points): + for ani_point, child_point in zip( + ani_spline.bezier_points, child_spline.bezier_points + ): # Set the position of that point ani_point.co = child_point.co ani_point.keyframe_insert(data_path="co", frame=frame) # Loop through all the materials that the object might have associated - for ani_material, child_material in zip(ani_obj.data.materials, child_obj.data.materials): + for ani_material, child_material in zip( + ani_obj.data.materials, child_obj.data.materials + ): ani_mat_inputs = ani_material.node_tree.nodes["Principled BSDF"].inputs child_mat_inputs = child_material.node_tree.nodes["Principled BSDF"].inputs for input_key in ("Base Color", "Alpha"): - ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value - ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) + ani_mat_inputs[input_key].default_value = child_mat_inputs[ + input_key + ].default_value + ani_mat_inputs[input_key].keyframe_insert( + data_path="default_value", frame=frame + ) def add_atoms_frame(ani_objects, child_objects, frame): @@ -55,7 +65,7 @@ def add_atoms_frame(ani_objects, child_objects, frame): child_objects: CollectionObjects the objects of the Atoms collection in the child plot. frame: int - the frame number to which the keyframe values should be set. + the frame number to which the keyframe values should be set. """ # Loop through all objects in the collections for ani_obj, child_obj in zip(ani_objects, child_objects): @@ -68,12 +78,20 @@ def add_atoms_frame(ani_objects, child_objects, frame): ani_obj.keyframe_insert(data_path="scale", frame=frame) # Set the atom color and opacity - ani_mat_inputs = ani_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs - child_mat_inputs = child_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + ani_mat_inputs = ( + ani_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + ) + child_mat_inputs = ( + child_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + ) for input_key in ("Base Color", "Alpha"): - ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value - ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) + ani_mat_inputs[input_key].default_value = child_mat_inputs[ + input_key + ].default_value + ani_mat_inputs[input_key].keyframe_insert( + data_path="default_value", frame=frame + ) class BlenderFigure(Figure): @@ -87,7 +105,7 @@ class BlenderFigure(Figure): """ # Experimental feature to adjust 2D plottings - #_2D_scale = (1, 1) + # _2D_scale = (1, 1) _animatable_collections = { "Lines": {"add_frame": add_line_frame}, @@ -99,24 +117,24 @@ def _init_figure(self, *args, **kwargs): self._collections = {} def _init_figure_animated(self, interpolated_frames: int = 5, **kwargs): - self._animation_settings = { - "interpolated_frames": interpolated_frames - } + self._animation_settings = {"interpolated_frames": interpolated_frames} return self._init_figure(**kwargs) def _iter_animation(self, plot_actions, interpolated_frames=5): - interpolated_frames = self._animation_settings["interpolated_frames"] - + for i, section_actions in enumerate(plot_actions): frame = i * interpolated_frames sanitized_section_actions = [] for action in section_actions: - action_name = action['method'] + action_name = action["method"] if action_name.startswith("draw_"): - action = {**action, "kwargs": {**action.get("kwargs", {}), "frame": frame}} - + action = { + **action, + "kwargs": {**action.get("kwargs", {}), "frame": frame}, + } + sanitized_section_actions.append(action) yield sanitized_section_actions @@ -125,10 +143,9 @@ def draw_on(self, figure): self._plot.get_figure(backend=self._backend_name, clear_fig=False) def clear(self): - """ Clears the blender scene so that data can be reset""" + """Clears the blender scene so that data can be reset""" for key, collection in self._collections.items(): - self.clear_collection(collection) bpy.data.collections.remove(collection) @@ -156,22 +173,50 @@ def clear_collection(self, collection): for obj in collection.objects: bpy.data.objects.remove(obj, do_unlink=True) - def draw_line(self, x, y, name="", line={}, marker={}, text=None, row=None, col=None, **kwargs): + def draw_line( + self, x, y, name="", line={}, marker={}, text=None, row=None, col=None, **kwargs + ): z = np.full_like(x, 0) # x = self._2D_scale[0] * x # y = self._2D_scale[1] * y - return self.draw_line_3D(x, y, z, name=name, line=line, marker=marker, text=text, row=row, col=col, **kwargs) - - def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs): + return self.draw_line_3D( + x, + y, + z, + name=name, + line=line, + marker=marker, + text=text, + row=row, + col=col, + **kwargs, + ) + + def draw_scatter( + self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs + ): z = np.full_like(x, 0) # x = self._2D_scale[0] * x # y = self._2D_scale[1] * y - return self.draw_scatter_3D(x, y, z, name=name, marker=marker, text=text, row=row, col=col, **kwargs) + return self.draw_scatter_3D( + x, y, z, name=name, marker=marker, text=text, row=row, col=col, **kwargs + ) - def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, **kwargs): + def draw_line_3D( + self, x, y, z, line={}, name="", collection=None, frame=None, **kwargs + ): """Draws a line using a bezier curve.""" if frame is not None: - return self._animate_line_3D(x, y, z, line=line, name=name, collection=collection, frame=frame, **kwargs) + return self._animate_line_3D( + x, + y, + z, + line=line, + name=name, + collection=collection, + frame=frame, + **kwargs, + ) if collection is None: collection = self.get_collection(name) @@ -193,8 +238,8 @@ def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, * # Retrieve the curve from the object curve = curve_obj.data # And modify some attributes to make it look cylindric - curve.dimensions = '3D' - curve.fill_mode = 'FULL' + curve.dimensions = "3D" + curve.fill_mode = "FULL" width = line.get("width") curve.bevel_depth = width if width is not None else 0.1 curve.bevel_resolution = 10 @@ -214,7 +259,7 @@ def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, * # Now loop through all segments using the known breakpoints for start_i, end_i in zip(breakpoint_indices, breakpoint_indices[1:]): # Get the coordinates of the segment - segment_xyz = xyz[start_i+1: end_i] + segment_xyz = xyz[start_i + 1 : end_i] # If there is nothing to draw, go to next segment if len(segment_xyz) == 0: @@ -225,7 +270,7 @@ def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, * # Splines by default have only 1 point, add as many as we need segment.bezier_points.add(len(segment_xyz) - 1) # Assign the coordinates to each point - segment.bezier_points.foreach_set('co', np.ravel(segment_xyz)) + segment.bezier_points.foreach_set("co", np.ravel(segment_xyz)) # We want linear interpolation between points. If we wanted cubic interpolation, # we would set this parameter to 3, for example. @@ -236,13 +281,24 @@ def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, * return self - def _animate_line_3D(self, x, y, z, line={}, name="", collection=None, frame=0, **kwargs): + def _animate_line_3D( + self, x, y, z, line={}, name="", collection=None, frame=0, **kwargs + ): if collection is None: collection = self.get_collection(name) # If this is the first frame, draw the object as usual if frame == 0: - self.draw_line_3D(x, y, z, line=line, name=name, collection=collection, frame=None, **kwargs) + self.draw_line_3D( + x, + y, + z, + line=line, + name=name, + collection=collection, + frame=None, + **kwargs, + ) # Create a collection that we are just going to use to create new objects from which # to copy the properties. @@ -250,38 +306,84 @@ def _animate_line_3D(self, x, y, z, line={}, name="", collection=None, frame=0, temp_collection = self.get_collection(temp_collection_name) self.clear_collection(temp_collection) - self.draw_line_3D(x, y, z, line=line, name=name, collection=temp_collection, frame=None, **kwargs) + self.draw_line_3D( + x, + y, + z, + line=line, + name=name, + collection=temp_collection, + frame=None, + **kwargs, + ) # Loop through all objects in the collections for ani_obj, child_obj in zip(collection.objects, temp_collection.objects): # Each curve object has multiple splines - for ani_spline, child_spline in zip(ani_obj.data.splines, child_obj.data.splines): + for ani_spline, child_spline in zip( + ani_obj.data.splines, child_obj.data.splines + ): # And each spline has multiple points - for ani_point, child_point in zip(ani_spline.bezier_points, child_spline.bezier_points): + for ani_point, child_point in zip( + ani_spline.bezier_points, child_spline.bezier_points + ): # Set the position of that point ani_point.co = child_point.co ani_point.keyframe_insert(data_path="co", frame=frame) # Loop through all the materials that the object might have associated - for ani_material, child_material in zip(ani_obj.data.materials, child_obj.data.materials): + for ani_material, child_material in zip( + ani_obj.data.materials, child_obj.data.materials + ): ani_mat_inputs = ani_material.node_tree.nodes["Principled BSDF"].inputs - child_mat_inputs = child_material.node_tree.nodes["Principled BSDF"].inputs + child_mat_inputs = child_material.node_tree.nodes[ + "Principled BSDF" + ].inputs for input_key in ("Base Color", "Alpha"): - ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value - ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) - + ani_mat_inputs[input_key].default_value = child_mat_inputs[ + input_key + ].default_value + ani_mat_inputs[input_key].keyframe_insert( + data_path="default_value", frame=frame + ) + # Remove the temporal collection self.remove_collection(temp_collection_name) - def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, collection=None, frame=None, **kwargs): + def draw_balls_3D( + self, + x, + y, + z, + name=None, + marker={}, + row=None, + col=None, + collection=None, + frame=None, + **kwargs, + ): if frame is not None: - return self._animate_balls_3D(x, y, z, name=name, marker=marker, row=row, col=col, collection=collection, frame=frame, **kwargs) - + return self._animate_balls_3D( + x, + y, + z, + name=name, + marker=marker, + row=row, + col=col, + collection=collection, + frame=frame, + **kwargs, + ) + if collection is None: collection = self.get_collection(name) - bpy.ops.surface.primitive_nurbs_surface_sphere_add(radius=1, enter_editmode=False, align='WORLD') + bpy.ops.surface.primitive_nurbs_surface_sphere_add( + radius=1, enter_editmode=False, align="WORLD" + ) template_ball = bpy.context.object bpy.context.collection.objects.unlink(template_ball) @@ -292,11 +394,15 @@ def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, colle } for k, v in style.items(): - if (not isinstance(v, (collections.abc.Sequence, np.ndarray))) or isinstance(v, str): + if ( + not isinstance(v, (collections.abc.Sequence, np.ndarray)) + ) or isinstance(v, str): style[k] = itertools.repeat(v) ball = template_ball - for i, (x_i, y_i, z_i, color, opacity, size) in enumerate(zip(x, y, z, style["color"], style["opacity"], style["size"])): + for i, (x_i, y_i, z_i, color, opacity, size) in enumerate( + zip(x, y, z, style["color"], style["opacity"], style["size"]) + ): if i > 0: ball = template_ball.copy() ball.data = template_ball.data.copy() @@ -310,15 +416,38 @@ def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, colle ball.name = f"{name}_{i}" ball.data.name = f"{name}_{i}" - self._color_obj(ball, color, opacity=opacity) - - def _animate_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, collection=None, frame=0, **kwargs): + self._color_obj(ball, color, opacity=opacity) + + def _animate_balls_3D( + self, + x, + y, + z, + name=None, + marker={}, + row=None, + col=None, + collection=None, + frame=0, + **kwargs, + ): if collection is None: collection = self.get_collection(name) # If this is the first frame, draw the object as usual if frame == 0: - self.draw_balls_3D(x, y, z, marker=marker, name=name, row=row, col=col, collection=collection, frame=None, **kwargs) + self.draw_balls_3D( + x, + y, + z, + marker=marker, + name=name, + row=row, + col=col, + collection=collection, + frame=None, + **kwargs, + ) # Create a collection that we are just going to use to create new objects from which # to copy the properties. @@ -326,7 +455,18 @@ def _animate_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, c temp_collection = self.get_collection(temp_collection_name) self.clear_collection(temp_collection) - self.draw_balls_3D(x, y, z, marker=marker, name=name, row=row, col=col, collection=temp_collection, frame=None, **kwargs) + self.draw_balls_3D( + x, + y, + z, + marker=marker, + name=name, + row=row, + col=col, + collection=temp_collection, + frame=None, + **kwargs, + ) # Loop through all objects in the collections for ani_obj, child_obj in zip(collection.objects, temp_collection.objects): @@ -339,18 +479,36 @@ def _animate_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, c ani_obj.keyframe_insert(data_path="scale", frame=frame) # Set the atom color and opacity - ani_mat_inputs = ani_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs - child_mat_inputs = child_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + ani_mat_inputs = ( + ani_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + ) + child_mat_inputs = ( + child_obj.data.materials[0].node_tree.nodes["Principled BSDF"].inputs + ) for input_key in ("Base Color", "Alpha"): - ani_mat_inputs[input_key].default_value = child_mat_inputs[input_key].default_value - ani_mat_inputs[input_key].keyframe_insert(data_path="default_value", frame=frame) + ani_mat_inputs[input_key].default_value = child_mat_inputs[ + input_key + ].default_value + ani_mat_inputs[input_key].keyframe_insert( + data_path="default_value", frame=frame + ) self.remove_collection(temp_collection_name) draw_scatter_3D = draw_balls_3D - def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name="Mesh", row=None, col=None, **kwargs): + def draw_mesh_3D( + self, + vertices, + faces, + color=None, + opacity=None, + name="Mesh", + row=None, + col=None, + **kwargs, + ): col = self.get_collection(name) mesh = bpy.data.meshes.new(name) @@ -366,21 +524,21 @@ def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name="Mesh", r @staticmethod def _to_rgb_color(color): - if isinstance(color, str): try: import matplotlib.colors color = matplotlib.colors.to_rgb(color) except ModuleNotFoundError: - raise ValueError("Blender does not understand string colors."+ - "Please provide the color in rgb (tuple of length 3, values from 0 to 1) or install matplotlib so that we can convert it." + raise ValueError( + "Blender does not understand string colors." + + "Please provide the color in rgb (tuple of length 3, values from 0 to 1) or install matplotlib so that we can convert it." ) return color @classmethod - def _color_obj(cls, obj, color, opacity=1.): + def _color_obj(cls, obj, color, opacity=1.0): """Utiity method to quickly color a given object. Parameters @@ -394,7 +552,7 @@ def _color_obj(cls, obj, color, opacity=1.): work currently. """ if opacity is None: - opacity = 1. + opacity = 1.0 color = cls._to_rgb_color(color) @@ -518,4 +676,4 @@ def show(self, *args, **kwargs): # collection = self.get_collection("Unit cell") # super()._draw_cell_3D_box(*args, width=width, collection=collection, **kwargs) -# GeometryPlot.backends.register("blender", BlenderGeometryBackend) \ No newline at end of file +# GeometryPlot.backends.register("blender", BlenderGeometryBackend) diff --git a/src/sisl/viz/figure/figure.py b/src/sisl/viz/figure/figure.py index 23ec368b89..53b3542d3a 100644 --- a/src/sisl/viz/figure/figure.py +++ b/src/sisl/viz/figure/figure.py @@ -8,6 +8,7 @@ BACKENDS = {} + class Figure: """Base figure class that all backends should inherit from. @@ -20,6 +21,7 @@ class Figure: To create a new backend, one might take the PlotlyFigure as a template. """ + _coloraxes: dict = {} _multi_axes: dict = {} @@ -40,54 +42,74 @@ def __init__(self, plot_actions, *args, **kwargs): self._build(plot_actions, *args, **kwargs) def _build(self, plot_actions, *args, **kwargs): - plot_actions = self._sanitize_plot_actions(plot_actions) self._coloraxes = {} self._multi_axes = {} fig = self.init_figure( - composite_method=plot_actions['composite_method'], - plot_actions=plot_actions['plot_actions'], - init_kwargs=plot_actions['init_kwargs'], + composite_method=plot_actions["composite_method"], + plot_actions=plot_actions["plot_actions"], + init_kwargs=plot_actions["init_kwargs"], ) - for section_actions in self._composite_iter(self._composite_mode, plot_actions['plot_actions']): + for section_actions in self._composite_iter( + self._composite_mode, plot_actions["plot_actions"] + ): for action in section_actions: - getattr(self, action['method'])(*action.get('args', ()), **action.get('kwargs', {})) - + getattr(self, action["method"])( + *action.get("args", ()), **action.get("kwargs", {}) + ) + return fig @staticmethod def _sanitize_plot_actions(plot_actions): - def _flatten(plot_actions, out, level=0, root_i=0): for i, section_actions in enumerate(plot_actions): if level == 0: out.append([]) root_i = i - + if isinstance(section_actions, dict): - _flatten(section_actions['plot_actions'], out, level + 1, root_i=root_i) + _flatten( + section_actions["plot_actions"], out, level + 1, root_i=root_i + ) else: # If it's a plot object, we need to extract the plot_actions out[root_i].extend(section_actions) if isinstance(plot_actions, dict): - composite_method = plot_actions.get('composite_method') - init_kwargs = plot_actions.get('init_kwargs', {}) + composite_method = plot_actions.get("composite_method") + init_kwargs = plot_actions.get("init_kwargs", {}) out = [] - _flatten(plot_actions['plot_actions'], out) + _flatten(plot_actions["plot_actions"], out) plot_actions = out else: composite_method = None plot_actions = [plot_actions] init_kwargs = {} - return {"composite_method": composite_method, "plot_actions": plot_actions, "init_kwargs": init_kwargs} + return { + "composite_method": composite_method, + "plot_actions": plot_actions, + "init_kwargs": init_kwargs, + } - def init_figure(self, composite_method: Literal[None, "same_axes", "multiple", "multiple_x", "multiple_y", "subplots", "animation"] = None, - plot_actions=(), init_kwargs: Dict[str, Any] = {}): + def init_figure( + self, + composite_method: Literal[ + None, + "same_axes", + "multiple", + "multiple_x", + "multiple_y", + "subplots", + "animation", + ] = None, + plot_actions=(), + init_kwargs: Dict[str, Any] = {}, + ): if composite_method is None: self._composite_mode = self._NONE return self._init_figure(**init_kwargs) @@ -97,61 +119,82 @@ def init_figure(self, composite_method: Literal[None, "same_axes", "multiple", " elif composite_method.startswith("multiple"): # This could be multiple self._composite_mode = self._MULTIAXIS - multi_axes = [ax for ax in 'xy' if ax in composite_method[8:]] - return self._init_figure_multiple_axes(multi_axes, plot_actions, **init_kwargs) + multi_axes = [ax for ax in "xy" if ax in composite_method[8:]] + return self._init_figure_multiple_axes( + multi_axes, plot_actions, **init_kwargs + ) elif composite_method == "animation": self._composite_mode = self._ANIMATION return self._init_figure_animated(n=len(plot_actions), **init_kwargs) elif composite_method == "subplots": self._composite_mode = self._SUBPLOTS self._rows, self._cols = self._subplots_rows_and_cols( - len(plot_actions), rows=init_kwargs.get('rows'), cols=init_kwargs.get('cols'), - arrange=init_kwargs.pop('arrange', "rows"), + len(plot_actions), + rows=init_kwargs.get("rows"), + cols=init_kwargs.get("cols"), + arrange=init_kwargs.pop("arrange", "rows"), + ) + init_kwargs = ChainMap( + {"rows": self._rows, "cols": self._cols}, init_kwargs ) - init_kwargs = ChainMap({'rows': self._rows, 'cols': self._cols}, init_kwargs) return self._init_figure_subplots(**init_kwargs) else: raise ValueError(f"Unknown composite method '{composite_method}'") def _init_figure(self, **kwargs): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _init_figure method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a _init_figure method." + ) def _init_figure_same_axes(self, *args, **kwargs): return self._init_figure(*args, **kwargs) - + def _init_figure_multiple_axes(self, multi_axes, plot_actions, **kwargs): figure = self._init_figure() if len(multi_axes) > 2: - raise ValueError(f"{self.__class__.__name__} doesn't support more than one multiple axes.") + raise ValueError( + f"{self.__class__.__name__} doesn't support more than one multiple axes." + ) for axis in multi_axes: self._multi_axes[axis] = self._init_multiaxis(axis, len(plot_actions)) return figure - + def _init_multiaxis(self, axis, n): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _init_multiaxis method.") - + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a _init_multiaxis method." + ) + def _init_figure_animated(self, **kwargs): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _init_figure_animated method.") - + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a _init_figure_animated method." + ) + def _init_figure_subplots(self, rows, cols, **kwargs): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _init_figure_subplots method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a _init_figure_subplots method." + ) - def _subplots_rows_and_cols(self, n: int, rows: Optional[int] = None, cols: Optional[int] = None, - arrange: Literal["rows", "cols", "square"] = "rows") -> Tuple[int, int]: - """ Returns the number of rows and columns for a subplot grid. """ + def _subplots_rows_and_cols( + self, + n: int, + rows: Optional[int] = None, + cols: Optional[int] = None, + arrange: Literal["rows", "cols", "square"] = "rows", + ) -> Tuple[int, int]: + """Returns the number of rows and columns for a subplot grid.""" if rows is None and cols is None: - if arrange == 'rows': + if arrange == "rows": rows = n cols = 1 - elif arrange == 'cols': + elif arrange == "cols": cols = n rows = 1 - elif arrange == 'square': - cols = n ** 0.5 - rows = n ** 0.5 + elif arrange == "square": + cols = n**0.5 + rows = n**0.5 # we will correct so it *fits*, always have more columns rows, cols = int(rows), int(cols) cols = n // rows + min(1, n % rows) @@ -165,10 +208,12 @@ def _subplots_rows_and_cols(self, n: int, rows: Optional[int] = None, cols: Opti rows, cols = int(rows), int(cols) if cols * rows < n: - warn(f"requested {n} subplots on a {rows}x{cols} grid layout. {n - cols*rows} plots will be missing.") + warn( + f"requested {n} subplots on a {rows}x{cols} grid layout. {n - cols*rows} plots will be missing." + ) return rows, cols - + def _composite_iter(self, mode, plot_actions): if mode == self._NONE: return plot_actions @@ -187,13 +232,19 @@ def _iter_same_axes(self, plot_actions): return plot_actions def _iter_multiaxis(self, plot_actions): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _iter_multiaxis method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a _iter_multiaxis method." + ) def _iter_subplots(self, plot_actions): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _iter_subplots method.") - + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a _iter_subplots method." + ) + def _iter_animation(self, plot_actions): - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a _iter_animation method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a _iter_animation method." + ) def clear(self): """Clears the figure so that we can draw again.""" @@ -206,17 +257,30 @@ def init_3D(self): """Called if functions that draw in 3D are going to be called.""" return - def init_coloraxis(self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs): + def init_coloraxis( + self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs + ): """Initializes a color axis to be used by the drawing functions""" self._coloraxes[name] = { - 'cmin': cmin, - 'cmax': cmax, - 'cmid': cmid, - 'colorscale': colorscale, - **kwargs + "cmin": cmin, + "cmax": cmax, + "cmid": cmid, + "colorscale": colorscale, + **kwargs, } - def draw_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, **kwargs): + def draw_line( + self, + x, + y, + name=None, + line={}, + marker={}, + text=None, + row=None, + col=None, + **kwargs, + ): """Draws a line satisfying the specifications Parameters @@ -245,31 +309,44 @@ def draw_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, co should allow other keyword arguments to be passed directly to the creation of the line. This will of course be framework specific """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_line method." + ) def draw_multicolor_line(self, *args, line={}, row=None, col=None, **kwargs): """By default, multicoloured lines are drawn simply by drawing scatter points.""" marker = { - **kwargs.pop('marker', {}), - 'color': line.get('color'), - 'size': line.get('width'), - 'opacity': line.get('opacity'), - 'coloraxis': line.get('coloraxis') + **kwargs.pop("marker", {}), + "color": line.get("color"), + "size": line.get("width"), + "opacity": line.get("opacity"), + "coloraxis": line.get("coloraxis"), } self.draw_multicolor_scatter(*args, marker=marker, row=row, col=col, **kwargs) def draw_multisize_line(self, *args, line={}, row=None, col=None, **kwargs): """By default, multisized lines are drawn simple by drawing scatter points.""" marker = { - **kwargs.pop('marker', {}), - 'color': line.get('color'), - 'size': line.get('width'), - 'opacity': line.get('opacity'), - 'coloraxis': line.get('coloraxis') + **kwargs.pop("marker", {}), + "color": line.get("color"), + "size": line.get("width"), + "opacity": line.get("opacity"), + "coloraxis": line.get("coloraxis"), } self.draw_multisize_scatter(*args, marker=marker, row=row, col=col, **kwargs) - def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + def draw_area_line( + self, + x, + y, + name=None, + line={}, + text=None, + dependent_axis=None, + row=None, + col=None, + **kwargs, + ): """Same as draw line, but to draw a line with an area. This is for example used to draw fatbands. Parameters @@ -299,9 +376,22 @@ def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=Non should allow other keyword arguments to be passed directly to the creation of the scatter. This will of course be framework specific """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_area_line method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_area_line method." + ) - def draw_multicolor_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + def draw_multicolor_area_line( + self, + x, + y, + name=None, + line={}, + text=None, + dependent_axis=None, + row=None, + col=None, + **kwargs, + ): """Draw a line with an area with multiple colours. Parameters @@ -331,9 +421,22 @@ def draw_multicolor_area_line(self, x, y, name=None, line={}, text=None, depende should allow other keyword arguments to be passed directly to the creation of the scatter. This will of course be framework specific """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_multicolor_area_line method.") - - def draw_multisize_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_multicolor_area_line method." + ) + + def draw_multisize_area_line( + self, + x, + y, + name=None, + line={}, + text=None, + dependent_axis=None, + row=None, + col=None, + **kwargs, + ): """Draw a line with an area with multiple colours. This is already usually supported by the normal draw_area_line. @@ -366,9 +469,21 @@ def draw_multisize_area_line(self, x, y, name=None, line={}, text=None, dependen the scatter. This will of course be framework specific """ # Usually, multisized area lines are already supported. - return self.draw_area_line(x, y, name=name, line=line, text=text, dependent_axis=dependent_axis, row=row, col=col, **kwargs) + return self.draw_area_line( + x, + y, + name=name, + line=line, + text=text, + dependent_axis=dependent_axis, + row=row, + col=col, + **kwargs, + ) - def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs): + def draw_scatter( + self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs + ): """Draws a scatter satisfying the specifications Parameters @@ -395,25 +510,39 @@ def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None should allow other keyword arguments to be passed directly to the creation of the scatter. This will of course be framework specific """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_scatter method." + ) def draw_multicolor_scatter(self, *args, **kwargs): """Draws a multicoloured scatter. - + Usually the normal scatter can already support this. """ # Usually, multicoloured scatter plots are already supported. return self.draw_scatter(*args, **kwargs) - + def draw_multisize_scatter(self, *args, **kwargs): """Draws a multisized scatter. - + Usually the normal scatter can already support this. """ # Usually, multisized scatter plots are already supported. return self.draw_scatter(*args, **kwargs) - def draw_arrows(self, x, y, dxy, arrowhead_scale=0.2, arrowhead_angle=20, scale: float = 1, annotate: bool = False, row=None, col=None, **kwargs): + def draw_arrows( + self, + x, + y, + dxy, + arrowhead_scale=0.2, + arrowhead_angle=20, + scale: float = 1, + annotate: bool = False, + row=None, + col=None, + **kwargs, + ): """Draws multiple arrows using the generic draw_line method. Parameters @@ -447,17 +576,22 @@ def draw_arrows(self, x, y, dxy, arrowhead_scale=0.2, arrowhead_angle=20, scale: arrowhead_angle = np.radians(arrowhead_angle) # Get the rotation matrices to get the tips of the arrowheads - rot_matrix = np.array([[np.cos(arrowhead_angle), -np.sin(arrowhead_angle)], [np.sin(arrowhead_angle), np.cos(arrowhead_angle)]]) + rot_matrix = np.array( + [ + [np.cos(arrowhead_angle), -np.sin(arrowhead_angle)], + [np.sin(arrowhead_angle), np.cos(arrowhead_angle)], + ] + ) inv_rot = np.linalg.inv(rot_matrix) # Calculate the tips of the arrow heads - arrowhead_tips1 = final_xy - (dxy*arrowhead_scale).dot(rot_matrix) - arrowhead_tips2 = final_xy - (dxy*arrowhead_scale).dot(inv_rot) + arrowhead_tips1 = final_xy - (dxy * arrowhead_scale).dot(rot_matrix) + arrowhead_tips2 = final_xy - (dxy * arrowhead_scale).dot(inv_rot) # Now build an array with all the information to draw the arrows # This has shape (n_arrows * 7, 2). The information to draw an arrow # occupies 7 rows and the columns are the x and y coordinates. - arrows = np.empty((xy.shape[0]*7, xy.shape[1]), dtype=np.float64) + arrows = np.empty((xy.shape[0] * 7, xy.shape[1]), dtype=np.float64) arrows[0::7] = xy arrows[1::7] = final_xy @@ -474,11 +608,30 @@ def draw_arrows(self, x, y, dxy, arrowhead_scale=0.2, arrowhead_angle=20, scale: # Add text annotations just at the tip of the arrows. annotate_text = np.full((arrows.shape[0],), "", dtype=object) annotate_text[4::7] = [str(xy / scale) for xy in dxy] - kwargs['text'] = list(annotate_text) - - return self.draw_line(arrows[:, 0], arrows[:, 1], hovertext=list(hovertext), row=row, col=col, **kwargs) + kwargs["text"] = list(annotate_text) + + return self.draw_line( + arrows[:, 0], + arrows[:, 1], + hovertext=list(hovertext), + row=row, + col=col, + **kwargs, + ) - def draw_line_3D(self, x, y, z, name=None, line={}, marker={}, text=None, row=None, col=None, **kwargs): + def draw_line_3D( + self, + x, + y, + z, + name=None, + line={}, + marker={}, + text=None, + row=None, + col=None, + **kwargs, + ): """Draws a 3D line satisfying the specifications. Parameters @@ -509,7 +662,9 @@ def draw_line_3D(self, x, y, z, name=None, line={}, marker={}, text=None, row=No should allow other keyword arguments to be passed directly to the creation of the line. This will of course be framework specific """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_line_3D method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_line_3D method." + ) def draw_multicolor_line_3D(self, *args, **kwargs): """Draws a multicoloured 3D line.""" @@ -519,7 +674,9 @@ def draw_multisize_line_3D(self, *args, **kwargs): """Draws a multisized 3D line.""" self.draw_line_3D(*args, **kwargs) - def draw_scatter_3D(self, x, y, z, name=None, marker={}, text=None, row=None, col=None, **kwargs): + def draw_scatter_3D( + self, x, y, z, name=None, marker={}, text=None, row=None, col=None, **kwargs + ): """Draws a 3D scatter satisfying the specifications Parameters @@ -547,52 +704,76 @@ def draw_scatter_3D(self, x, y, z, name=None, marker={}, text=None, row=None, co should allow other keyword arguments to be passed directly to the creation of the scatter. This will of course be framework specific """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_scatter_3D method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_scatter_3D method." + ) def draw_multicolor_scatter_3D(self, *args, **kwargs): """Draws a multicoloured 3D scatter. - + Usually the normal 3D scatter can already support this. """ # Usually, multicoloured scatter plots are already supported. return self.draw_scatter_3D(*args, **kwargs) - + def draw_multisize_scatter_3D(self, *args, **kwargs): """Draws a multisized 3D scatter. - + Usually the normal 3D scatter can already support this. """ # Usually, multisized scatter plots are already supported. return self.draw_scatter_3D(*args, **kwargs) - def draw_balls_3D(self, x, y, z, name=None, markers={}, row=None, col=None, **kwargs): + def draw_balls_3D( + self, x, y, z, name=None, markers={}, row=None, col=None, **kwargs + ): """Draws points as 3D spheres.""" - return NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_balls_3D method.") + return NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_balls_3D method." + ) - def draw_multicolor_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs): + def draw_multicolor_balls_3D( + self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs + ): """Draws points as 3D spheres with different colours. - - If marker_color is an array of numbers, a coloraxis is created and values are converted to rgb. + + If marker_color is an array of numbers, a coloraxis is created and values are converted to rgb. """ - kwargs['marker'] = marker.copy() + kwargs["marker"] = marker.copy() - if 'color' in marker and np.array(marker['color']).dtype in (int, float): - coloraxis = kwargs['marker']['coloraxis'] + if "color" in marker and np.array(marker["color"]).dtype in (int, float): + coloraxis = kwargs["marker"]["coloraxis"] coloraxis = self._coloraxes[coloraxis] - kwargs['marker']['color'] = values_to_colors(kwargs['marker']['color'], coloraxis['colorscale'] or "viridis") + kwargs["marker"]["color"] = values_to_colors( + kwargs["marker"]["color"], coloraxis["colorscale"] or "viridis" + ) return self.draw_balls_3D(x, y, z, name=name, row=row, col=col, **kwargs) - - def draw_multisize_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs): + + def draw_multisize_balls_3D( + self, x, y, z, name=None, marker={}, row=None, col=None, **kwargs + ): """Draws points as 3D spheres with different sizes. Usually supported by the normal draw_balls_3D """ return self.draw_balls_3D(x, y, z, name=name, row=row, col=col, **kwargs) - def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, scale: float = 1, row=None, col=None, **kwargs): + def draw_arrows_3D( + self, + x, + y, + z, + dxyz, + arrowhead_scale=0.3, + arrowhead_angle=15, + scale: float = 1, + row=None, + col=None, + **kwargs, + ): """Draws multiple 3D arrows using the generic draw_line_3D method. Parameters @@ -633,8 +814,18 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, dxy_norm = np.linalg.norm(dxyz[:, :2], axis=1) # Some vectors might be only in the Z direction, which will result in dxy_norm being 0. # We avoid problems by dividinc - dx_p = np.divide(dxyz[:, 1], dxy_norm, where=dxy_norm != 0, out=np.zeros(dxyz.shape[0], dtype=np.float64)) - dy_p = np.divide(-dxyz[:, 0], dxy_norm, where=dxy_norm != 0, out=np.ones(dxyz.shape[0], dtype=np.float64)) + dx_p = np.divide( + dxyz[:, 1], + dxy_norm, + where=dxy_norm != 0, + out=np.zeros(dxyz.shape[0], dtype=np.float64), + ) + dy_p = np.divide( + -dxyz[:, 0], + dxy_norm, + where=dxy_norm != 0, + out=np.ones(dxyz.shape[0], dtype=np.float64), + ) # And then we build the rotation matrices. Since each arrow needs a unique rotation matrix, # we will have n 3x3 matrices, where n is the number of arrows, for each arrowhead tip. @@ -643,22 +834,29 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, # Rotation matrix to build the first arrowhead tip positions. rot_matrices = np.array( - [[c + (dx_p ** 2) * (1 - c), dx_p * dy_p * (1 - c), dy_p * s], - [dy_p * dx_p * (1 - c), c + (dy_p ** 2) * (1 - c), -dx_p * s], - [-dy_p * s, dx_p * s, np.full_like(dx_p, c)]]) + [ + [c + (dx_p**2) * (1 - c), dx_p * dy_p * (1 - c), dy_p * s], + [dy_p * dx_p * (1 - c), c + (dy_p**2) * (1 - c), -dx_p * s], + [-dy_p * s, dx_p * s, np.full_like(dx_p, c)], + ] + ) # The opposite rotation matrix, to get the other arrowhead's tip positions. inv_rots = rot_matrices.copy() inv_rots[[0, 1, 2, 2], [2, 2, 0, 1]] *= -1 # Calculate the tips of the arrow heads. - arrowhead_tips1 = final_xyz - np.einsum("ij...,...j->...i", rot_matrices, dxyz * arrowhead_scale) - arrowhead_tips2 = final_xyz - np.einsum("ij...,...j->...i", inv_rots, dxyz * arrowhead_scale) + arrowhead_tips1 = final_xyz - np.einsum( + "ij...,...j->...i", rot_matrices, dxyz * arrowhead_scale + ) + arrowhead_tips2 = final_xyz - np.einsum( + "ij...,...j->...i", inv_rots, dxyz * arrowhead_scale + ) # Now build an array with all the information to draw the arrows # This has shape (n_arrows * 7, 3). The information to draw an arrow # occupies 7 rows and the columns are the x and y coordinates. - arrows = np.empty((xyz.shape[0]*7, 3)) + arrows = np.empty((xyz.shape[0] * 7, 3)) arrows[0::7] = xyz arrows[1::7] = final_xyz @@ -668,15 +866,42 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, arrows[5::7] = arrowhead_tips2 arrows[6::7] = np.nan - return self.draw_line_3D(arrows[:, 0], arrows[:, 1], arrows[:, 2], row=row, col=col, **kwargs) + return self.draw_line_3D( + arrows[:, 0], arrows[:, 1], arrows[:, 2], row=row, col=col, **kwargs + ) - def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, **kwargs): + def draw_heatmap( + self, + values, + x=None, + y=None, + name=None, + zsmooth=False, + coloraxis=None, + row=None, + col=None, + **kwargs, + ): """Draws a heatmap following the specifications.""" - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_heatmap method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_heatmap method." + ) - def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name=None, row=None, col=None, **kwargs): + def draw_mesh_3D( + self, + vertices, + faces, + color=None, + opacity=None, + name=None, + row=None, + col=None, + **kwargs, + ): """Draws a 3D mesh following the specifications.""" - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a draw_mesh_3D method.") + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a draw_mesh_3D method." + ) def set_axis(self, **kwargs): """Sets the axis parameters. @@ -685,11 +910,13 @@ def set_axis(self, **kwargs): reference for consistency. Other frameworks should translate the calls to their functionality. """ - + def set_axes_equal(self): """Sets the axes equal.""" - raise NotImplementedError(f"{self.__class__.__name__} doesn't implement a set_axes_equal method.") - + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement a set_axes_equal method." + ) + def to(self, key: str): """Converts the figure to another backend. @@ -700,6 +927,7 @@ def to(self, key: str): """ return BACKENDS[key](self.plot_actions) + def get_figure(backend: str, plot_actions, *args, **kwargs) -> Figure: """Get a figure object. @@ -712,4 +940,4 @@ def get_figure(backend: str, plot_actions, *args, **kwargs) -> Figure: *args, **kwargs passed to the figure constructor """ - return BACKENDS[backend](plot_actions, *args, **kwargs) \ No newline at end of file + return BACKENDS[backend](plot_actions, *args, **kwargs) diff --git a/src/sisl/viz/figure/matplotlib.py b/src/sisl/viz/figure/matplotlib.py index e31ecd0a40..c05a1e635f 100644 --- a/src/sisl/viz/figure/matplotlib.py +++ b/src/sisl/viz/figure/matplotlib.py @@ -19,7 +19,7 @@ class MatplotlibFigure(Figure): If an attribute is not found on the backend, it is looked for in the axes. - On initialization, we also take the class attribute `_axes_defaults` (a dictionary) + On initialization, we also take the class attribute `_axes_defaults` (a dictionary) and run `self.axes.update` with those parameters. Therefore this parameter can be used to provide default parameters for the axes. """ @@ -84,7 +84,7 @@ def _init_figure_subplots(self, rows, cols, **kwargs): self.axes = np.expand_dims(self.axes, axis=1) return self.figure - + def _get_subplot_axes(self, row=None, col=None) -> plt.Axes: if row is None or col is None: # This is not a subplot @@ -93,12 +93,10 @@ def _get_subplot_axes(self, row=None, col=None) -> plt.Axes: return self.axes[row, col] def _iter_subplots(self, plot_actions): - it = zip(itertools.product(range(self._rows), range(self._cols)), plot_actions) # Start assigning each plot to a position of the layout for i, ((row, col), section_actions) in enumerate(it): - row_col_kwargs = {"row": row, "col": col} # active_axes = { # ax: f"{ax}axis" if row == 0 and col == 0 else f"{ax}axis{i + 1}" @@ -107,22 +105,28 @@ def _iter_subplots(self, plot_actions): sanitized_section_actions = [] for action in section_actions: - action_name = action['method'] + action_name = action["method"] if action_name.startswith("draw_"): - action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} + action = { + **action, + "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}, + } elif action_name.startswith("set_ax"): - action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} - + action = { + **action, + "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}, + } + sanitized_section_actions.append(action) yield sanitized_section_actions - def _init_figure_multiple_axes(self, multi_axes, plot_actions, **kwargs): - if len(multi_axes) > 1: self.figure = plt.figure() - self.axes = self.figure.add_axes([0.15, 0.1, 0.65, 0.8], axes_class=HostAxes) + self.axes = self.figure.add_axes( + [0.15, 0.1, 0.65, 0.8], axes_class=HostAxes + ) self._init_axes() multi_axis = "xy" @@ -130,12 +134,13 @@ def _init_figure_multiple_axes(self, multi_axes, plot_actions, **kwargs): self.figure = self._init_figure() multi_axis = multi_axes[0] - self._multi_axes[multi_axis] = self._init_multiaxis(multi_axis, len(plot_actions)) + self._multi_axes[multi_axis] = self._init_multiaxis( + multi_axis, len(plot_actions) + ) return self.figure def _init_multiaxis(self, axis, n): - axes = [self.axes] for i in range(n - 1): if axis == "x": @@ -157,22 +162,28 @@ def _init_multiaxis(self, axis, n): self.axes.parasites.append(new_axes) axes.append(new_axes) - + return axes def _iter_multiaxis(self, plot_actions): multi_axis = list(self._multi_axes)[0] for i, section_actions in enumerate(plot_actions): axes = self._multi_axes[multi_axis][i] - + sanitized_section_actions = [] for action in section_actions: - action_name = action['method'] + action_name = action["method"] if action_name.startswith("draw_"): - action = {**action, "kwargs": {**action.get("kwargs", {}), "_axes": axes}} + action = { + **action, + "kwargs": {**action.get("kwargs", {}), "_axes": axes}, + } elif action_name == "set_axis": - action = {**action, "kwargs": {**action.get("kwargs", {}), "_axes": axes}} - + action = { + **action, + "kwargs": {**action.get("kwargs", {}), "_axes": axes}, + } + sanitized_section_actions.append(action) yield sanitized_section_actions @@ -183,7 +194,7 @@ def __getattr__(self, key): raise AttributeError(key) def clear(self, layout=False): - """ Clears the plot canvas so that data can be reset + """Clears the plot canvas so that data can be reset Parameters -------- @@ -204,26 +215,66 @@ def get_ipywidget(self): def show(self, *args, **kwargs): return self.figure.show(*args, **kwargs) - def draw_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, _axes=None, **kwargs): + def draw_line( + self, + x, + y, + name=None, + line={}, + marker={}, + text=None, + row=None, + col=None, + _axes=None, + **kwargs, + ): marker_format = marker.get("symbol", "o") if marker else None marker_color = marker.get("color") axes = _axes or self._get_subplot_axes(row=row, col=col) return axes.plot( - x, y, color=line.get("color"), linewidth=line.get("width", 1), - marker=marker_format, markersize=marker.get("size"), markerfacecolor=marker_color, markeredgecolor=marker_color, - label=name + x, + y, + color=line.get("color"), + linewidth=line.get("width", 1), + marker=marker_format, + markersize=marker.get("size"), + markerfacecolor=marker_color, + markeredgecolor=marker_color, + label=name, ) - def draw_multicolor_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, _axes=None, **kwargs): + def draw_multicolor_line( + self, + x, + y, + name=None, + line={}, + marker={}, + text=None, + row=None, + col=None, + _axes=None, + **kwargs, + ): # This is heavily based on # https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html color = line.get("color") if not np.issubdtype(np.array(color).dtype, np.number): - return self.draw_multicolor_scatter(x, y, name=name, marker=line, text=text, row=row, col=col, _axes=_axes, **kwargs) + return self.draw_multicolor_scatter( + x, + y, + name=name, + marker=line, + text=text, + row=row, + col=col, + _axes=_axes, + **kwargs, + ) points = np.array([x, y]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) @@ -234,7 +285,7 @@ def draw_multicolor_line(self, x, y, name=None, line={}, marker={}, text=None, r coloraxis = self._coloraxes.get(coloraxis) lc_kwargs["cmap"] = coloraxis.get("colorscale") if coloraxis.get("cmin") is not None: - lc_kwargs["norm"] = Normalize(coloraxis['cmin'], coloraxis['cmax']) + lc_kwargs["norm"] = Normalize(coloraxis["cmin"], coloraxis["cmax"]) lc = LineCollection(segments, **lc_kwargs) @@ -246,9 +297,21 @@ def draw_multicolor_line(self, x, y, name=None, line={}, marker={}, text=None, r axes.add_collection(lc) - #self._colorbar = axes.add_collection(lc) - - def draw_multisize_line(self, x, y, name=None, line={}, marker={}, text=None, row=None, col=None, _axes=None, **kwargs): + # self._colorbar = axes.add_collection(lc) + + def draw_multisize_line( + self, + x, + y, + name=None, + line={}, + marker={}, + text=None, + row=None, + col=None, + _axes=None, + **kwargs, + ): points = np.array([x, y]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) @@ -261,9 +324,19 @@ def draw_multisize_line(self, x, y, name=None, line={}, marker={}, text=None, ro axes.add_collection(lc) - def draw_area_line(self, x, y, line={}, name=None, dependent_axis=None, row=None, col=None, _axes=None, **kwargs): - - width = line.get('width') + def draw_area_line( + self, + x, + y, + line={}, + name=None, + dependent_axis=None, + row=None, + col=None, + _axes=None, + **kwargs, + ): + width = line.get("width") if width is None: width = 1 spacing = width / 2 @@ -272,38 +345,83 @@ def draw_area_line(self, x, y, line={}, name=None, dependent_axis=None, row=None if dependent_axis in ("y", None): axes.fill_between( - x, y + spacing, y - spacing, - color=line.get('color'), label=name + x, y + spacing, y - spacing, color=line.get("color"), label=name ) elif dependent_axis == "x": axes.fill_betweenx( - y, x + spacing, x - spacing, - color=line.get('color'), label=name + y, x + spacing, x - spacing, color=line.get("color"), label=name ) else: - raise ValueError(f"dependent_axis must be one of 'x', 'y', or None, but was {dependent_axis}") + raise ValueError( + f"dependent_axis must be one of 'x', 'y', or None, but was {dependent_axis}" + ) - def draw_scatter(self, x, y, name=None, marker={}, text=None, zorder=2, row=None, col=None, _axes=None, meta={}, **kwargs): + def draw_scatter( + self, + x, + y, + name=None, + marker={}, + text=None, + zorder=2, + row=None, + col=None, + _axes=None, + meta={}, + **kwargs, + ): axes = _axes or self._get_subplot_axes(row=row, col=col) try: - return axes.scatter(x, y, c=marker.get("color"), s=marker.get("size", 1), cmap=marker.get("colorscale"), alpha=marker.get("opacity"), label=name, zorder=zorder, **kwargs) + return axes.scatter( + x, + y, + c=marker.get("color"), + s=marker.get("size", 1), + cmap=marker.get("colorscale"), + alpha=marker.get("opacity"), + label=name, + zorder=zorder, + **kwargs, + ) except TypeError as e: if str(e) == "alpha must be a float or None": - warn(f"Your matplotlib version doesn't support multiple opacity values, please upgrade to >=3.4 if you want to use opacity.") - return axes.scatter(x, y, c=marker.get("color"), s=marker.get("size", 1), cmap=marker.get("colorscale"), label=name, zorder=zorder, **kwargs) + warn( + f"Your matplotlib version doesn't support multiple opacity values, please upgrade to >=3.4 if you want to use opacity." + ) + return axes.scatter( + x, + y, + c=marker.get("color"), + s=marker.get("size", 1), + cmap=marker.get("colorscale"), + label=name, + zorder=zorder, + **kwargs, + ) else: raise e - + def draw_multicolor_scatter(self, *args, **kwargs): - marker = {**kwargs.pop("marker",{})} + marker = {**kwargs.pop("marker", {})} coloraxis = marker.get("coloraxis") if coloraxis is not None: coloraxis = self._coloraxes.get(coloraxis) marker["colorscale"] = coloraxis.get("colorscale") return super().draw_multicolor_scatter(*args, marker=marker, **kwargs) - - def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, _axes=None, **kwargs): + def draw_heatmap( + self, + values, + x=None, + y=None, + name=None, + zsmooth=False, + coloraxis=None, + row=None, + col=None, + _axes=None, + **kwargs, + ): extent = None if x is not None and y is not None: extent = [x[0], x[-1], y[0], y[-1]] @@ -316,30 +434,44 @@ def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, colorax vmax = coloraxis.get("cmax") axes.imshow( - values, - cmap=colorscale, - vmin=vmin, vmax=vmax, - label=name, extent=extent, - origin="lower" + values, + cmap=colorscale, + vmin=vmin, + vmax=vmax, + label=name, + extent=extent, + origin="lower", ) - - def set_axis(self, axis, range=None, title="", tickvals=None, ticktext=None, showgrid=False, row=None, col=None, _axes=None, **kwargs): + + def set_axis( + self, + axis, + range=None, + title="", + tickvals=None, + ticktext=None, + showgrid=False, + row=None, + col=None, + _axes=None, + **kwargs, + ): axes = _axes or self._get_subplot_axes(row=row, col=col) if range is not None: - updater = getattr(axes, f'set_{axis}lim') + updater = getattr(axes, f"set_{axis}lim") updater(*range) if title: - updater = getattr(axes, f'set_{axis}label') + updater = getattr(axes, f"set_{axis}label") updater(title) - + if tickvals is not None: - updater = getattr(axes, f'set_{axis}ticks') + updater = getattr(axes, f"set_{axis}ticks") updater(ticks=tickvals, labels=ticktext) axes.grid(visible=showgrid, axis=axis) - + def set_axes_equal(self, row=None, col=None, _axes=None): axes = _axes or self._get_subplot_axes(row=row, col=col) - axes.axis("equal") \ No newline at end of file + axes.axis("equal") diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py index 46909cc753..ce55895280 100644 --- a/src/sisl/viz/figure/plotly.py +++ b/src/sisl/viz/figure/plotly.py @@ -13,27 +13,55 @@ layout={ "plot_bgcolor": "white", "paper_bgcolor": "white", - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "black"), ("showgrid", False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", "#ccc"), ("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, + **{ + f"{ax}_{key}": val + for ax, (key, val) in itertools.product( + ("xaxis", "yaxis"), + ( + ("visible", True), + ("showline", True), + ("linewidth", 1), + ("mirror", True), + ("color", "black"), + ("showgrid", False), + ("gridcolor", "#ccc"), + ("gridwidth", 1), + ("zeroline", False), + ("zerolinecolor", "#ccc"), + ("zerolinewidth", 1), + ("ticks", "outside"), + ("ticklen", 5), + ("ticksuffix", " "), + ), + ) + }, "hovermode": "closest", "scene": { - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis", "zaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "black"), ("showgrid", - False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", - "#ccc"), ("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, + **{ + f"{ax}_{key}": val + for ax, (key, val) in itertools.product( + ("xaxis", "yaxis", "zaxis"), + ( + ("visible", True), + ("showline", True), + ("linewidth", 1), + ("mirror", True), + ("color", "black"), + ("showgrid", False), + ("gridcolor", "#ccc"), + ("gridwidth", 1), + ("zeroline", False), + ("zerolinecolor", "#ccc"), + ("zerolinewidth", 1), + ("ticks", "outside"), + ("ticklen", 5), + ("ticksuffix", " "), + ), + ) + }, } - #"editrevision": True - #"title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} + # "editrevision": True + # "title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} }, ) @@ -41,29 +69,56 @@ layout={ "plot_bgcolor": "black", "paper_bgcolor": "black", - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "white"), ("showgrid", - False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", "#ccc"), ("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, - "font": {'color': 'white'}, + **{ + f"{ax}_{key}": val + for ax, (key, val) in itertools.product( + ("xaxis", "yaxis"), + ( + ("visible", True), + ("showline", True), + ("linewidth", 1), + ("mirror", True), + ("color", "white"), + ("showgrid", False), + ("gridcolor", "#ccc"), + ("gridwidth", 1), + ("zeroline", False), + ("zerolinecolor", "#ccc"), + ("zerolinewidth", 1), + ("ticks", "outside"), + ("ticklen", 5), + ("ticksuffix", " "), + ), + ) + }, + "font": {"color": "white"}, "hovermode": "closest", "scene": { - **{f"{ax}_{key}": val for ax, (key, val) in itertools.product( - ("xaxis", "yaxis", "zaxis"), - (("visible", True), ("showline", True), ("linewidth", 1), ("mirror", True), - ("color", "white"), ("showgrid", - False), ("gridcolor", "#ccc"), ("gridwidth", 1), - ("zeroline", False), ("zerolinecolor", - "#ccc"), ("zerolinewidth", 1), - ("ticks", "outside"), ("ticklen", 5), ("ticksuffix", " ")) - )}, + **{ + f"{ax}_{key}": val + for ax, (key, val) in itertools.product( + ("xaxis", "yaxis", "zaxis"), + ( + ("visible", True), + ("showline", True), + ("linewidth", 1), + ("mirror", True), + ("color", "white"), + ("showgrid", False), + ("gridcolor", "#ccc"), + ("gridwidth", 1), + ("zeroline", False), + ("zerolinecolor", "#ccc"), + ("zerolinewidth", 1), + ("ticks", "outside"), + ("ticklen", 5), + ("ticksuffix", " "), + ), + ) + }, } - #"editrevision": True - #"title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} + # "editrevision": True + # "title": {"xref": "paper", "x": 0.5, "text": "Whhhhhhhat up", "pad": {"b": 0}} }, ) @@ -72,17 +127,19 @@ # so that it doesn't affect plots outside sisl. pio.templates.default = "sisl" + class PlotlyFigure(Figure): """Generic canvas for the plotly framework. - On initialization, a plotly.graph_objs.Figure object is created and stored + On initialization, a plotly.graph_objs.Figure object is created and stored under `self.figure`. If an attribute is not found on the backend, it is looked for in the figure. Therefore, you can apply all the methods that are appliable to a plotly figure! - On initialization, we also take the class attribute `_layout_defaults` (a dictionary) + On initialization, we also take the class attribute `_layout_defaults` (a dictionary) and run `update_layout` with those parameters. """ + _multi_axis = None _layout_defaults = {} @@ -91,24 +148,25 @@ def _init_figure(self, *args, **kwargs): self.figure = go.Figure() self.update_layout(**self._layout_defaults) return self - - def _init_figure_subplots(self, rows, cols, **kwargs): + def _init_figure_subplots(self, rows, cols, **kwargs): figure = self._init_figure() - figure.set_subplots(**{ - "rows": rows, "cols": cols, **kwargs, - }) + figure.set_subplots( + **{ + "rows": rows, + "cols": cols, + **kwargs, + } + ) return figure def _iter_subplots(self, plot_actions): - it = zip(itertools.product(range(self._rows), range(self._cols)), plot_actions) # Start assigning each plot to a position of the layout for i, ((row, col), section_actions) in enumerate(it): - row_col_kwargs = {"row": row + 1, "col": col + 1} active_axes = { ax: f"{ax}axis" if row == 0 and col == 0 else f"{ax}axis{i + 1}" @@ -117,21 +175,34 @@ def _iter_subplots(self, plot_actions): sanitized_section_actions = [] for action in section_actions: - action_name = action['method'] + action_name = action["method"] if action_name.startswith("draw_"): - action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} - action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} + action = { + **action, + "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}, + } + action["kwargs"]["meta"] = { + **action["kwargs"].get("meta", {}), + "i_plot": i, + } elif action_name.startswith("set_ax"): - action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} - + action = { + **action, + "kwargs": { + **action.get("kwargs", {}), + "_active_axes": active_axes, + }, + } + sanitized_section_actions.append(action) yield sanitized_section_actions def _init_multiaxis(self, axis, n): - - axes = [f"{axis}{i + 1}" if i > 0 else axis for i in range(n) ] - layout_axes = [f"{axis}axis{i + 1}" if i > 0 else f"{axis}axis" for i in range(n) ] + axes = [f"{axis}{i + 1}" if i > 0 else axis for i in range(n)] + layout_axes = [ + f"{axis}axis{i + 1}" if i > 0 else f"{axis}axis" for i in range(n) + ] if axis == "x": sides = ["bottom", "top"] elif axis == "y": @@ -141,47 +212,68 @@ def _init_multiaxis(self, axis, n): layout_updates = {} for ax, side in zip(layout_axes, itertools.cycle(sides)): - layout_updates[ax] = {'side': side, 'overlaying': axis} - layout_updates[f"{axis}axis"]['overlaying'] = None + layout_updates[ax] = {"side": side, "overlaying": axis} + layout_updates[f"{axis}axis"]["overlaying"] = None self.update_layout(**layout_updates) - + return layout_axes def _iter_multiaxis(self, plot_actions): - for i, section_actions in enumerate(plot_actions): active_axes = {ax: v[i] for ax, v in self._multi_axes.items()} - active_axes_kwargs = {f"{ax}axis": v.replace("axis", "") for ax, v in active_axes.items()} - + active_axes_kwargs = { + f"{ax}axis": v.replace("axis", "") for ax, v in active_axes.items() + } + sanitized_section_actions = [] for action in section_actions: - action_name = action['method'] + action_name = action["method"] if action_name.startswith("draw_"): - action = {**action, "kwargs": {**action.get("kwargs", {}), **active_axes_kwargs}} - action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} + action = { + **action, + "kwargs": {**action.get("kwargs", {}), **active_axes_kwargs}, + } + action["kwargs"]["meta"] = { + **action["kwargs"].get("meta", {}), + "i_plot": i, + } elif action_name.startswith("set_ax"): - action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} - + action = { + **action, + "kwargs": { + **action.get("kwargs", {}), + "_active_axes": active_axes, + }, + } + sanitized_section_actions.append(action) yield sanitized_section_actions - - def _iter_same_axes(self, plot_actions): + def _iter_same_axes(self, plot_actions): for i, section_actions in enumerate(plot_actions): - sanitized_section_actions = [] for action in section_actions: - action_name = action['method'] + action_name = action["method"] if action_name.startswith("draw_"): action = {**action, "kwargs": action.get("kwargs", {})} - action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} - + action["kwargs"]["meta"] = { + **action["kwargs"].get("meta", {}), + "i_plot": i, + } + sanitized_section_actions.append(action) yield sanitized_section_actions - def _init_figure_animated(self, frame_names: Optional[Sequence[str]] = None, frame_duration: int = 500, transition: int = 300, redraw: bool = False, **kwargs): + def _init_figure_animated( + self, + frame_names: Optional[Sequence[str]] = None, + frame_duration: int = 500, + transition: int = 300, + redraw: bool = False, + **kwargs, + ): self._animation_settings = { "frame_names": frame_names, "frame_duration": frame_duration, @@ -191,9 +283,8 @@ def _init_figure_animated(self, frame_names: Optional[Sequence[str]] = None, fra self._animate_frame_names = frame_names self._animate_init_kwargs = kwargs return self._init_figure(**kwargs) - - def _iter_animation(self, plot_actions): + def _iter_animation(self, plot_actions): frame_duration = self._animation_settings["frame_duration"] transition = self._animation_settings["transition"] redraw = self._animation_settings["redraw"] @@ -206,18 +297,26 @@ def _iter_animation(self, plot_actions): frames = [] for i, section_actions in enumerate(plot_actions): - sanitized_section_actions = [] for action in section_actions: - action_name = action['method'] + action_name = action["method"] if action_name.startswith("draw_"): action = {**action, "kwargs": action.get("kwargs", {})} - action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} + action["kwargs"]["meta"] = { + **action["kwargs"].get("meta", {}), + "i_plot": i, + } yield sanitized_section_actions # Create a frame and append it - frames.append(go.Frame(name=frame_names[i],data=self.figure.data, layout=self.figure.layout)) + frames.append( + go.Frame( + name=frame_names[i], + data=self.figure.data, + layout=self.figure.layout, + ) + ) # Reinit the figure self._init_figure(**self._animate_init_kwargs) @@ -225,14 +324,19 @@ def _iter_animation(self, plot_actions): self.figure.update(data=frames[0].data, frames=frames) slider_steps = [ - {"args": [ - [frame["name"]], - {"frame": {"duration": int(frame_duration), "redraw": redraw}, - "mode": "immediate", - "transition": {"duration": transition}} - ], - "label": frame["name"], - "method": "animate"} for frame in self.figure.frames + { + "args": [ + [frame["name"]], + { + "frame": {"duration": int(frame_duration), "redraw": redraw}, + "mode": "immediate", + "transition": {"duration": transition}, + }, + ], + "label": frame["name"], + "method": "animate", + } + for frame in self.figure.frames ] slider = { @@ -241,38 +345,53 @@ def _iter_animation(self, plot_actions): "xanchor": "left", "currentvalue": { "font": {"size": 20}, - #"prefix": "Bands file:", + # "prefix": "Bands file:", "visible": True, - "xanchor": "right" + "xanchor": "right", }, - #"transition": {"duration": 300, "easing": "cubic-in-out"}, + # "transition": {"duration": 300, "easing": "cubic-in-out"}, "pad": {"b": 10, "t": 50}, "len": 0.9, "x": 0.1, "y": 0, - "steps": slider_steps + "steps": slider_steps, } # Buttons to play and pause the animation updatemenus = [ - - {'type': 'buttons', - 'buttons': [ - { - 'label': '▶', - 'method': 'animate', - 'args': [None, {"frame": {"duration": int(frame_duration), "redraw": redraw}, - "fromcurrent": True, "transition": {"duration": 100}}], - }, - - { - 'label': '⏸', - 'method': 'animate', - 'args': [[None], {"frame": {"duration": 0}, "redraw": redraw, - 'mode': 'immediate', - "transition": {"duration": 0}}], - } - ]} + { + "type": "buttons", + "buttons": [ + { + "label": "▶", + "method": "animate", + "args": [ + None, + { + "frame": { + "duration": int(frame_duration), + "redraw": redraw, + }, + "fromcurrent": True, + "transition": {"duration": 100}, + }, + ], + }, + { + "label": "⏸", + "method": "animate", + "args": [ + [None], + { + "frame": {"duration": 0}, + "redraw": redraw, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + }, + ], + } ] self.update_layout(sliders=[slider], updatemenus=updatemenus) @@ -286,7 +405,7 @@ def show(self, *args, **kwargs): return self.figure.show(*args, **kwargs) def clear(self, frames=True, layout=False): - """ Clears the plot canvas so that data can be reset + """Clears the plot canvas so that data can be reset Parameters -------- @@ -308,34 +427,47 @@ def clear(self, frames=True, layout=False): # -------------------------------- # METHODS TO STANDARIZE BACKENDS # -------------------------------- - def init_coloraxis(self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs): + def init_coloraxis( + self, name, cmin=None, cmax=None, cmid=None, colorscale=None, **kwargs + ): if len(self._coloraxes) == 0: - kwargs['ax_name'] = "coloraxis" + kwargs["ax_name"] = "coloraxis" else: - kwargs['ax_name'] = f'coloraxis{len(self._coloraxes) + 1}' + kwargs["ax_name"] = f"coloraxis{len(self._coloraxes) + 1}" super().init_coloraxis(name, cmin, cmax, cmid, colorscale, **kwargs) - - ax_name = kwargs['ax_name'] - self.update_layout(**{ax_name: {"colorscale": colorscale, "cmin": cmin, "cmax": cmax, "cmid": cmid}}) - def _get_coloraxis_name(self, coloraxis: Optional[str]): + ax_name = kwargs["ax_name"] + self.update_layout( + **{ + ax_name: { + "colorscale": colorscale, + "cmin": cmin, + "cmax": cmax, + "cmid": cmid, + } + } + ) + def _get_coloraxis_name(self, coloraxis: Optional[str]): if coloraxis in self._coloraxes: - return self._coloraxes[coloraxis]['ax_name'] + return self._coloraxes[coloraxis]["ax_name"] else: return coloraxis def _handle_multicolor_scatter(self, marker, scatter_kwargs): - - if 'coloraxis' in marker: + if "coloraxis" in marker: marker = marker.copy() - coloraxis = marker['coloraxis'] + coloraxis = marker["coloraxis"] if coloraxis is not None: - scatter_kwargs['hovertemplate'] = "x: %{x:.2f}
y: %{y:.2f}
" + coloraxis + ": %{marker.color:.2f}" - marker['coloraxis'] = self._get_coloraxis_name(coloraxis) - + scatter_kwargs["hovertemplate"] = ( + "x: %{x:.2f}
y: %{y:.2f}
" + + coloraxis + + ": %{marker.color:.2f}" + ) + marker["coloraxis"] = self._get_coloraxis_name(coloraxis) + return marker def draw_line(self, x, y, name=None, line={}, row=None, col=None, **kwargs): @@ -351,33 +483,48 @@ def draw_line(self, x, y, name=None, line={}, row=None, col=None, **kwargs): mode += "+text" # Finally, we add the trace. - self.add_trace({ - 'type': 'scatter', - 'x': x, - 'y': y, - 'mode': mode, - 'name': name, - 'line': {k: v for k, v in line.items() if k != "opacity"}, - 'opacity': opacity, - "meta": kwargs.pop("meta", {}), - **kwargs, - }, row=row, col=col) + self.add_trace( + { + "type": "scatter", + "x": x, + "y": y, + "mode": mode, + "name": name, + "line": {k: v for k, v in line.items() if k != "opacity"}, + "opacity": opacity, + "meta": kwargs.pop("meta", {}), + **kwargs, + }, + row=row, + col=col, + ) def draw_multicolor_line(self, *args, **kwargs): - kwargs['marker_line_width'] = 0 + kwargs["marker_line_width"] = 0 super().draw_multicolor_line(*args, **kwargs) def draw_multisize_line(self, *args, **kwargs): - kwargs['marker_line_width'] = 0 + kwargs["marker_line_width"] = 0 super().draw_multisize_line(*args, **kwargs) - def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=None, row=None, col=None, **kwargs): + def draw_area_line( + self, + x, + y, + name=None, + line={}, + text=None, + dependent_axis=None, + row=None, + col=None, + **kwargs, + ): chunk_x = x chunk_y = y - width = line.get('width') + width = line.get("width") if width is None: width = 1 chunk_spacing = width / 2 @@ -385,11 +532,17 @@ def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=Non if dependent_axis is None: # We draw the area line using the perpendicular direction to the line, because we don't know which # direction should we draw it in. - normal = np.array([- np.gradient(y), np.gradient(x)]).T + normal = np.array([-np.gradient(y), np.gradient(x)]).T norms = normal / np.linalg.norm(normal, axis=1).reshape(-1, 1) - x = [*(chunk_x + norms[:, 0] * chunk_spacing), *reversed(chunk_x - norms[:, 0] * chunk_spacing)] - y = [*(chunk_y + norms[:, 1] * chunk_spacing), *reversed(chunk_y - norms[:, 1] * chunk_spacing)] + x = [ + *(chunk_x + norms[:, 0] * chunk_spacing), + *reversed(chunk_x - norms[:, 0] * chunk_spacing), + ] + y = [ + *(chunk_y + norms[:, 1] * chunk_spacing), + *reversed(chunk_y - norms[:, 1] * chunk_spacing), + ] elif dependent_axis == "y": x = [*chunk_x, *reversed(chunk_x)] y = [*(chunk_y + chunk_spacing), *reversed(chunk_y - chunk_spacing)] @@ -399,26 +552,29 @@ def draw_area_line(self, x, y, name=None, line={}, text=None, dependent_axis=Non else: raise ValueError(f"Invalid dependent axis: {dependent_axis}") - self.add_trace({ - "type": "scatter", - "mode": "lines", - "x": x, - "y": y, - "line": {"width": 0, "color": line.get('color')}, - "name": name, - "legendgroup": name, - "showlegend": kwargs.pop("showlegend", None), - "fill": "toself", - "meta": kwargs.pop("meta", {}) - }, row=row, col=col) + self.add_trace( + { + "type": "scatter", + "mode": "lines", + "x": x, + "y": y, + "line": {"width": 0, "color": line.get("color")}, + "name": name, + "legendgroup": name, + "showlegend": kwargs.pop("showlegend", None), + "fill": "toself", + "meta": kwargs.pop("meta", {}), + }, + row=row, + col=col, + ) def draw_scatter(self, x, y, name=None, marker={}, **kwargs): marker.pop("dash", None) self.draw_line(x, y, name, marker=marker, mode="markers", **kwargs) - + def draw_multicolor_scatter(self, *args, **kwargs): - - kwargs['marker'] = self._handle_multicolor_scatter(kwargs['marker'], kwargs) + kwargs["marker"] = self._handle_multicolor_scatter(kwargs["marker"], kwargs) super().draw_multicolor_scatter(*args, **kwargs) @@ -426,21 +582,19 @@ def draw_line_3D(self, x, y, z, **kwargs): self.draw_line(x, y, type="scatter3d", z=z, **kwargs) def draw_multicolor_line_3D(self, x, y, z, **kwargs): - kwargs['line'] = self._handle_multicolor_scatter(kwargs['line'], kwargs) + kwargs["line"] = self._handle_multicolor_scatter(kwargs["line"], kwargs) super().draw_multicolor_line_3D(x, y, z, **kwargs) def draw_scatter_3D(self, *args, **kwargs): self.draw_line_3D(*args, mode="markers", **kwargs) - + def draw_multicolor_scatter_3D(self, *args, **kwargs): - - kwargs['marker'] = self._handle_multicolor_scatter(kwargs['marker'], kwargs) + kwargs["marker"] = self._handle_multicolor_scatter(kwargs["marker"], kwargs) super().draw_multicolor_scatter_3D(*args, **kwargs) def draw_balls_3D(self, x, y, z, name=None, marker={}, **kwargs): - style = {} for k in ("size", "color", "opacity"): val = marker.get(k) @@ -450,36 +604,82 @@ def draw_balls_3D(self, x, y, z, name=None, marker={}, **kwargs): style[k] = val - iterator = enumerate(zip(np.array(x), np.array(y), np.array(z), style["size"], style["color"], style["opacity"])) + iterator = enumerate( + zip( + np.array(x), + np.array(y), + np.array(z), + style["size"], + style["color"], + style["opacity"], + ) + ) meta = kwargs.pop("meta", {}) showlegend = True for i, (sp_x, sp_y, sp_z, sp_size, sp_color, sp_opacity) in iterator: self.draw_ball_3D( - xyz=[sp_x, sp_y, sp_z], - size=sp_size, color=sp_color, opacity=sp_opacity, + xyz=[sp_x, sp_y, sp_z], + size=sp_size, + color=sp_color, + opacity=sp_opacity, name=f"{name}_{i}", - legendgroup=name, showlegend=showlegend, - meta=meta + legendgroup=name, + showlegend=showlegend, + meta=meta, ) showlegend = False return - def draw_ball_3D(self, xyz, size, color="gray", name=None, vertices=15, row=None, col=None, **kwargs): - self.add_trace({ - 'type': 'mesh3d', - **{key: val for key, val in sphere(center=xyz, r=size, vertices=vertices).items()}, - 'alphahull': 0, - 'color': color, - 'showscale': False, - 'name': name, - 'meta': {"position": '({:.2f}, {:.2f}, {:.2f})'.format(*xyz), "meta": kwargs.pop("meta", {})}, - 'hovertemplate': '%{meta.position}', - **kwargs - }, row=None, col=None) - - def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_angle=20, arrowhead_scale=0.3, scale: float = 1, row=None, col=None, **kwargs): + def draw_ball_3D( + self, + xyz, + size, + color="gray", + name=None, + vertices=15, + row=None, + col=None, + **kwargs, + ): + self.add_trace( + { + "type": "mesh3d", + **{ + key: val + for key, val in sphere( + center=xyz, r=size, vertices=vertices + ).items() + }, + "alphahull": 0, + "color": color, + "showscale": False, + "name": name, + "meta": { + "position": "({:.2f}, {:.2f}, {:.2f})".format(*xyz), + "meta": kwargs.pop("meta", {}), + }, + "hovertemplate": "%{meta.position}", + **kwargs, + }, + row=None, + col=None, + ) + + def draw_arrows_3D( + self, + x, + y, + z, + dxyz, + arrowhead_angle=20, + arrowhead_scale=0.3, + scale: float = 1, + row=None, + col=None, + **kwargs, + ): """Draws 3D arrows in plotly using a combination of a scatter3D and a Cone trace.""" # Make sure we are dealing with numpy arrays xyz = np.array([x, y, z]).T @@ -491,13 +691,13 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_angle=20, arrowhead_scale=0.3, color = line.get("color") if color is None: color = "red" - line['color'] = color + line["color"] = color # 3D lines don't support opacity line.pop("opacity", None) name = kwargs.get("name", "Arrows") - arrows_coords = np.empty((xyz.shape[0]*3, 3), dtype=np.float64) + arrows_coords = np.empty((xyz.shape[0] * 3, 3), dtype=np.float64) arrows_coords[0::3] = xyz arrows_coords[1::3] = final_xyz @@ -507,72 +707,109 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_angle=20, arrowhead_scale=0.3, rows_cols = {} if row is not None: - rows_cols['rows'] = [row, row] + rows_cols["rows"] = [row, row] if col is not None: - rows_cols['cols'] = [col, col] + rows_cols["cols"] = [col, col] meta = kwargs.pop("meta", {}) - - - self.figure.add_traces([{ - "x": arrows_coords[:, 0], - "y": arrows_coords[:, 1], - "z": arrows_coords[:, 2], - "mode": "lines", - "type": "scatter3d", - "hoverinfo": "none", - "line": line, - "legendgroup": name, - "name": f"{name} lines", - "showlegend": False, - "meta": meta - }, - { - "type": "cone", - "x": conebase_xyz[:, 0], - "y": conebase_xyz[:, 1], - "z": conebase_xyz[:, 2], - "u": arrowhead_scale * dxyz[:, 0], - "v": arrowhead_scale * dxyz[:, 1], - "w": arrowhead_scale * dxyz[:, 2], - "hovertemplate": "[%{u}, %{v}, %{w}]", - "sizemode": "absolute", - "sizeref": arrowhead_scale * np.linalg.norm(dxyz, axis=1).max() / 2, - "colorscale": [[0, color], [1, color]], - "showscale": False, - "legendgroup": name, - "name": name, - "showlegend": True, - "meta": meta - }], **rows_cols) - - def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None, **kwargs): - - self.add_trace({ - 'type': 'heatmap', 'z': values, - 'x': x, 'y': y, - 'name': name, - 'zsmooth': zsmooth, - 'coloraxis': self._get_coloraxis_name(coloraxis), - 'meta': kwargs.pop("meta", {}), - }, row=row, col=col) - - def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name=None, row=None, col=None, **kwargs): + self.figure.add_traces( + [ + { + "x": arrows_coords[:, 0], + "y": arrows_coords[:, 1], + "z": arrows_coords[:, 2], + "mode": "lines", + "type": "scatter3d", + "hoverinfo": "none", + "line": line, + "legendgroup": name, + "name": f"{name} lines", + "showlegend": False, + "meta": meta, + }, + { + "type": "cone", + "x": conebase_xyz[:, 0], + "y": conebase_xyz[:, 1], + "z": conebase_xyz[:, 2], + "u": arrowhead_scale * dxyz[:, 0], + "v": arrowhead_scale * dxyz[:, 1], + "w": arrowhead_scale * dxyz[:, 2], + "hovertemplate": "[%{u}, %{v}, %{w}]", + "sizemode": "absolute", + "sizeref": arrowhead_scale * np.linalg.norm(dxyz, axis=1).max() / 2, + "colorscale": [[0, color], [1, color]], + "showscale": False, + "legendgroup": name, + "name": name, + "showlegend": True, + "meta": meta, + }, + ], + **rows_cols, + ) + + def draw_heatmap( + self, + values, + x=None, + y=None, + name=None, + zsmooth=False, + coloraxis=None, + row=None, + col=None, + **kwargs, + ): + self.add_trace( + { + "type": "heatmap", + "z": values, + "x": x, + "y": y, + "name": name, + "zsmooth": zsmooth, + "coloraxis": self._get_coloraxis_name(coloraxis), + "meta": kwargs.pop("meta", {}), + }, + row=row, + col=col, + ) + + def draw_mesh_3D( + self, + vertices, + faces, + color=None, + opacity=None, + name=None, + row=None, + col=None, + **kwargs, + ): x, y, z = vertices.T I, J, K = faces.T - self.add_trace(dict( - type="mesh3d", - x=x, y=y, z=z, - i=I, j=J, k=K, - color=color, - opacity=opacity, - name=name, - showlegend=True, - meta=kwargs.pop("meta", {}), - **kwargs - ), row=row, col=col) + self.add_trace( + dict( + type="mesh3d", + x=x, + y=y, + z=z, + i=I, + j=J, + k=K, + color=color, + opacity=opacity, + name=name, + showlegend=True, + meta=kwargs.pop("meta", {}), + **kwargs, + ), + row=row, + col=col, + ) def set_axis(self, axis, _active_axes={}, **kwargs): if axis in _active_axes: @@ -591,6 +828,6 @@ def set_axis(self, axis, _active_axes={}, **kwargs): def set_axes_equal(self, _active_axes={}): x_axis = _active_axes.get("x", "xaxis") y_axis = _active_axes.get("y", "yaxis").replace("axis", "") - + self.update_layout({x_axis: {"scaleanchor": y_axis, "scaleratio": 1}}) - self.update_layout(scene_aspectmode="data") \ No newline at end of file + self.update_layout(scene_aspectmode="data") diff --git a/src/sisl/viz/figure/py3dmol.py b/src/sisl/viz/figure/py3dmol.py index d63a815e41..69d220a34b 100644 --- a/src/sisl/viz/figure/py3dmol.py +++ b/src/sisl/viz/figure/py3dmol.py @@ -16,19 +16,38 @@ class Py3DmolFigure(Figure): def _init_figure(self, *args, **kwargs): self.figure = py3Dmol.view() - def draw_line(self, x, y, name="", line={}, marker={}, text=None, row=None, col=None, **kwargs): + def draw_line( + self, x, y, name="", line={}, marker={}, text=None, row=None, col=None, **kwargs + ): z = np.full_like(x, 0) # x = self._2D_scale[0] * x # y = self._2D_scale[1] * y - return self.draw_line_3D(x, y, z, name=name, line=line, marker=marker, text=text, row=row, col=col, **kwargs) - - def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs): + return self.draw_line_3D( + x, + y, + z, + name=name, + line=line, + marker=marker, + text=text, + row=row, + col=col, + **kwargs, + ) + + def draw_scatter( + self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs + ): z = np.full_like(x, 0) # x = self._2D_scale[0] * x # y = self._2D_scale[1] * y - return self.draw_scatter_3D(x, y, z, name=name, marker=marker, text=text, row=row, col=col, **kwargs) + return self.draw_scatter_3D( + x, y, z, name=name, marker=marker, text=text, row=row, col=col, **kwargs + ) - def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, **kwargs): + def draw_line_3D( + self, x, y, z, line={}, name="", collection=None, frame=None, **kwargs + ): """Draws a line.""" xyz = np.array([x, y, z], dtype=float).T @@ -44,7 +63,7 @@ def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, * # Now loop through all segments using the known breakpoints for start_i, end_i in zip(breakpoint_indices, breakpoint_indices[1:]): # Get the coordinates of the segment - segment_xyz = xyz[start_i+1: end_i] + segment_xyz = xyz[start_i + 1 : end_i] # If there is nothing to draw, go to next segment if len(segment_xyz) == 0: @@ -52,41 +71,77 @@ def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, * points = [{"x": x, "y": y, "z": z} for x, y, z in segment_xyz] - # If there's only two points, py3dmol doesn't display the curve, + # If there's only two points, py3dmol doesn't display the curve, # probably because it can not smooth it. if len(points) == 2: points.append(points[-1]) - self.figure.addCurve(dict( - points=points, - radius=line.get("width", 0.1), - color=line.get("color"), - opacity=line.get('opacity', 1.) or 1., - smooth=1 - )) + self.figure.addCurve( + dict( + points=points, + radius=line.get("width", 0.1), + color=line.get("color"), + opacity=line.get("opacity", 1.0) or 1.0, + smooth=1, + ) + ) return self - def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, collection=None, frame=None, **kwargs): + def draw_balls_3D( + self, + x, + y, + z, + name=None, + marker={}, + row=None, + col=None, + collection=None, + frame=None, + **kwargs, + ): style = { "color": marker.get("color", "gray"), - "opacity": marker.get("opacity", 1.), - "size": marker.get("size", 1.), + "opacity": marker.get("opacity", 1.0), + "size": marker.get("size", 1.0), } for k, v in style.items(): - if (not isinstance(v, (collections.abc.Sequence, np.ndarray))) or isinstance(v, str): + if ( + not isinstance(v, (collections.abc.Sequence, np.ndarray)) + ) or isinstance(v, str): style[k] = itertools.repeat(v) - for i, (x_i, y_i, z_i, color, opacity, size) in enumerate(zip(x, y, z, style["color"], style["opacity"], style["size"])): - self.figure.addSphere(dict( - center={"x": float(x_i), "y": float(y_i), "z": float(z_i)}, radius=size, color=color, opacity=opacity, - quality=5., # This does not work, but sphere quality is really bad by default - )) + for i, (x_i, y_i, z_i, color, opacity, size) in enumerate( + zip(x, y, z, style["color"], style["opacity"], style["size"]) + ): + self.figure.addSphere( + dict( + center={"x": float(x_i), "y": float(y_i), "z": float(z_i)}, + radius=size, + color=color, + opacity=opacity, + quality=5.0, # This does not work, but sphere quality is really bad by default + ) + ) draw_scatter_3D = draw_balls_3D - def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, scale: float = 1, row=None, col=None, line={},**kwargs): + def draw_arrows_3D( + self, + x, + y, + z, + dxyz, + arrowhead_scale=0.3, + arrowhead_angle=15, + scale: float = 1, + row=None, + col=None, + line={}, + **kwargs, + ): """Draws multiple arrows using the generic draw_line method. Parameters @@ -112,29 +167,42 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, dxyz = np.array(dxyz) * scale for (x, y, z), (dx, dy, dz) in zip(xyz, dxyz): - - self.figure.addArrow(dict( - start={"x": x, "y": y, "z": z}, - end={"x": x + dx, "y": y + dy, "z": z + dz}, - radius=line.get("width", 0.1), - color=line.get("color"), - opacity=line.get("opacity", 1.), - radiusRatio=2, - mid=(1 - arrowhead_scale), - )) - - def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name="Mesh", wireframe=False, row=None, col=None, **kwargs): - + self.figure.addArrow( + dict( + start={"x": x, "y": y, "z": z}, + end={"x": x + dx, "y": y + dy, "z": z + dz}, + radius=line.get("width", 0.1), + color=line.get("color"), + opacity=line.get("opacity", 1.0), + radiusRatio=2, + mid=(1 - arrowhead_scale), + ) + ) + + def draw_mesh_3D( + self, + vertices, + faces, + color=None, + opacity=None, + name="Mesh", + wireframe=False, + row=None, + col=None, + **kwargs, + ): def vec_to_dict(a, labels="xyz"): - return dict(zip(labels,a)) - - self.figure.addCustom(dict( - vertexArr=[vec_to_dict(v) for v in vertices.astype(float)], - faceArr=[int(x) for f in faces for x in f], - color=color, - opacity=float(opacity or 1.), - wireframe=wireframe - )) + return dict(zip(labels, a)) + + self.figure.addCustom( + dict( + vertexArr=[vec_to_dict(v) for v in vertices.astype(float)], + faceArr=[int(x) for f in faces for x in f], + color=color, + opacity=float(opacity or 1.0), + wireframe=wireframe, + ) + ) def set_axis(self, *args, **kwargs): """There are no axes titles and these kind of things in py3dmol. @@ -145,4 +213,4 @@ def set_axes_equal(self, *args, **kwargs): def show(self, *args, **kwargs): self.figure.zoomTo() - return self.figure.show() \ No newline at end of file + return self.figure.show() diff --git a/src/sisl/viz/plot.py b/src/sisl/viz/plot.py index fbbecf8ca7..0ab3b41689 100644 --- a/src/sisl/viz/plot.py +++ b/src/sisl/viz/plot.py @@ -10,15 +10,19 @@ def __getattr__(self, key): return getattr(self.nodes.output.get(), key) else: return super().__getattr__(key) - + def merge(self, *others, **kwargs): from .plots.merged import merge_plots + return merge_plots(self, *others, **kwargs) - + def update_settings(self, *args, **kwargs): - deprecate("f{self.__class__.__name__}.update_settings is deprecated. Please use update_inputs.", "0.15") + deprecate( + "f{self.__class__.__name__}.update_settings is deprecated. Please use update_inputs.", + "0.15", + ) return self.update_inputs(*args, **kwargs) - + @classmethod def plot_class_key(cls) -> str: return cls.__name__.replace("Plot", "").lower() diff --git a/src/sisl/viz/plots/bands.py b/src/sisl/viz/plots/bands.py index 144d4dea4b..a4b97ea47c 100644 --- a/src/sisl/viz/plots/bands.py +++ b/src/sisl/viz/plots/bands.py @@ -18,22 +18,30 @@ from .orbital_groups_plot import OrbitalGroupsPlot -def bands_plot(bands_data: BandsData, - Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", - bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, - bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, - spindown_style: StyleSpec = {"color": "blue", "width": 1}, +def bands_plot( + bands_data: BandsData, + Erange: Optional[Tuple[float, float]] = None, + E0: float = 0.0, + E_axis: Literal["x", "y"] = "y", + bands_range: Optional[Tuple[int, int]] = None, + spin: Optional[Literal[0, 1]] = None, + bands_style: StyleSpec = {"color": "black", "width": 1, "opacity": 1}, + spindown_style: StyleSpec = {"color": "blue", "width": 1}, colorscale: Optional[str] = None, - gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, + gap: bool = False, + gap_tol: float = 0.01, + gap_color: str = "red", + gap_marker: dict = {"size": 7}, + direct_gaps_only: bool = False, custom_gaps: Sequence[Dict] = [], - line_mode: Literal["line", "scatter", "area_line"] = "line", - backend: str = "plotly" + line_mode: Literal["line", "scatter", "area_line"] = "line", + backend: str = "plotly", ) -> Figure: """Plots band structure energies, with plentiful of customization options. Parameters ---------- - bands_data: + bands_data: The object containing the data to plot. Erange: The energy range to plot. @@ -77,33 +85,55 @@ def bands_plot(bands_data: BandsData, bands_data = accept_data(bands_data, cls=BandsData, check=True) # Filter the bands - filtered_bands = filter_bands(bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin) + filtered_bands = filter_bands( + bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin + ) # Add the styles - styled_bands = style_bands(filtered_bands, bands_style=bands_style, spindown_style=spindown_style) + styled_bands = style_bands( + filtered_bands, bands_style=bands_style, spindown_style=spindown_style + ) # Determine what goes on each axis x = matches(E_axis, "x", ret_true="E", ret_false="k") y = matches(E_axis, "y", ret_true="E", ret_false="k") - + # Get the actions to plot lines - bands_plottings = draw_xarray_xy(data=styled_bands, x=x, y=y, set_axrange=True, what=line_mode, colorscale=colorscale, dependent_axis=E_axis) + bands_plottings = draw_xarray_xy( + data=styled_bands, + x=x, + y=y, + set_axrange=True, + what=line_mode, + colorscale=colorscale, + dependent_axis=E_axis, + ) # Gap calculation gap_info = calculate_gap(filtered_bands) # Plot it if the user has asked for it. - gaps_plottings = draw_gaps(bands_data, gap, gap_info, gap_tol, gap_color, gap_marker, direct_gaps_only, custom_gaps, E_axis=E_axis) + gaps_plottings = draw_gaps( + bands_data, + gap, + gap_info, + gap_tol, + gap_color, + gap_marker, + direct_gaps_only, + custom_gaps, + E_axis=E_axis, + ) all_plottings = combined(bands_plottings, gaps_plottings, composite_method=None) return get_figure(backend=backend, plot_actions=all_plottings) + def _default_random_color(x): return x.get("color") or random_color() def _group_traces(actions): - seen_groups = [] new_actions = [] @@ -111,40 +141,48 @@ def _group_traces(actions): if action["method"].startswith("draw_"): group = action["kwargs"].get("name") action = action.copy() - action['kwargs']['legendgroup'] = group + action["kwargs"]["legendgroup"] = group if group in seen_groups: action["kwargs"]["showlegend"] = False else: seen_groups.append(group) - + new_actions.append(action) - + return new_actions # I keep the fatbands plot here so that one can see how similar they are. # I am yet to find a nice solution for extending workflows. -def fatbands_plot(bands_data: BandsData, - Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", - bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, - bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, - spindown_style: StyleSpec = {"color": "blue", "width": 1}, - gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, +def fatbands_plot( + bands_data: BandsData, + Erange: Optional[Tuple[float, float]] = None, + E0: float = 0.0, + E_axis: Literal["x", "y"] = "y", + bands_range: Optional[Tuple[int, int]] = None, + spin: Optional[Literal[0, 1]] = None, + bands_style: StyleSpec = {"color": "black", "width": 1, "opacity": 1}, + spindown_style: StyleSpec = {"color": "blue", "width": 1}, + gap: bool = False, + gap_tol: float = 0.01, + gap_color: str = "red", + gap_marker: dict = {"size": 7}, + direct_gaps_only: bool = False, custom_gaps: Sequence[Dict] = [], - bands_mode: Literal["line", "scatter", "area_line"] = "line", + bands_mode: Literal["line", "scatter", "area_line"] = "line", # Fatbands inputs groups: OrbitalQueries = [], - fatbands_var: str = "norm2", + fatbands_var: str = "norm2", fatbands_mode: Literal["line", "scatter", "area_line"] = "area_line", - fatbands_scale: float = 1., - backend: str = "plotly" + fatbands_scale: float = 1.0, + backend: str = "plotly", ) -> Figure: """Plots band structure energies showing the contribution of orbitals to each state. Parameters ---------- - bands_data: + bands_data: The object containing the data to plot. Erange: The energy range to plot. @@ -193,52 +231,98 @@ def fatbands_plot(bands_data: BandsData, bands_data = accept_data(bands_data, cls=BandsData, check=True) # Filter the bands - filtered_bands = filter_bands(bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin) + filtered_bands = filter_bands( + bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin + ) # Add the styles - styled_bands = style_bands(filtered_bands, bands_style=bands_style, spindown_style=spindown_style) + styled_bands = style_bands( + filtered_bands, bands_style=bands_style, spindown_style=spindown_style + ) # Process fatbands orbital_manager = get_orbital_queries_manager( bands_data, key_gens={ "color": _default_random_color, - } + }, ) fatbands_data = reduce_orbital_data( - filtered_bands, groups=groups, orb_dim="orb", spin_dim="spin", sanitize_group=orbital_manager, - group_vars=('color', 'dash'), groups_dim="group", drop_empty=True, + filtered_bands, + groups=groups, + orb_dim="orb", + spin_dim="spin", + sanitize_group=orbital_manager, + group_vars=("color", "dash"), + groups_dim="group", + drop_empty=True, spin_reduce=np.sum, ) - scaled_fatbands_data = scale_variable(fatbands_data, var=fatbands_var, scale=fatbands_scale, default_value=1, allow_not_present=True) + scaled_fatbands_data = scale_variable( + fatbands_data, + var=fatbands_var, + scale=fatbands_scale, + default_value=1, + allow_not_present=True, + ) # Determine what goes on each axis x = matches(E_axis, "x", ret_true="E", ret_false="k") y = matches(E_axis, "y", ret_true="E", ret_false="k") - sanitized_fatbands_mode = matches(groups, [], ret_true="none", ret_false=fatbands_mode) - + sanitized_fatbands_mode = matches( + groups, [], ret_true="none", ret_false=fatbands_mode + ) + # Get the actions to plot lines fatbands_plottings = draw_xarray_xy( - data=scaled_fatbands_data, x=x, y=y, color="color", width=fatbands_var, what=sanitized_fatbands_mode, dependent_axis=E_axis, - name="group" + data=scaled_fatbands_data, + x=x, + y=y, + color="color", + width=fatbands_var, + what=sanitized_fatbands_mode, + dependent_axis=E_axis, + name="group", ) grouped_fatbands_plottings = _group_traces(fatbands_plottings) - bands_plottings = draw_xarray_xy(data=styled_bands, x=x, y=y, set_axrange=True, what=bands_mode, dependent_axis=E_axis) + bands_plottings = draw_xarray_xy( + data=styled_bands, + x=x, + y=y, + set_axrange=True, + what=bands_mode, + dependent_axis=E_axis, + ) # Gap calculation gap_info = calculate_gap(filtered_bands) # Plot it if the user has asked for it. - gaps_plottings = draw_gaps(bands_data, gap, gap_info, gap_tol, gap_color, gap_marker, direct_gaps_only, custom_gaps, E_axis=E_axis) + gaps_plottings = draw_gaps( + bands_data, + gap, + gap_info, + gap_tol, + gap_color, + gap_marker, + direct_gaps_only, + custom_gaps, + E_axis=E_axis, + ) - all_plottings = combined(grouped_fatbands_plottings, bands_plottings, gaps_plottings, composite_method=None) + all_plottings = combined( + grouped_fatbands_plottings, + bands_plottings, + gaps_plottings, + composite_method=None, + ) return get_figure(backend=backend, plot_actions=all_plottings) -class BandsPlot(Plot): +class BandsPlot(Plot): function = staticmethod(bands_plot) -class FatbandsPlot(OrbitalGroupsPlot): - function = staticmethod(fatbands_plot) \ No newline at end of file +class FatbandsPlot(OrbitalGroupsPlot): + function = staticmethod(fatbands_plot) diff --git a/src/sisl/viz/plots/geometry.py b/src/sisl/viz/plots/geometry.py index 580f8afaf0..5ebc90760b 100644 --- a/src/sisl/viz/plots/geometry.py +++ b/src/sisl/viz/plots/geometry.py @@ -33,73 +33,89 @@ def _get_atom_mode(drawing_mode, ndim): - if drawing_mode is None: if ndim == 3: - return 'balls' + return "balls" else: - return 'scatter' - + return "scatter" + return drawing_mode -def _get_arrow_plottings(atoms_data, arrows, nsc=[1,1,1]): - + +def _get_arrow_plottings(atoms_data, arrows, nsc=[1, 1, 1]): reps = np.prod(nsc) actions = [] atoms_data = atoms_data.unstack("sc_atom") for arrows_spec in arrows: - filtered = atoms_data.sel(atom=arrows_spec['atoms']) - dxy = arrows_spec['data'][arrows_spec['atoms']] - dxy = np.tile(np.ravel(dxy), reps).reshape(-1, arrows_spec['data'].shape[-1]) + filtered = atoms_data.sel(atom=arrows_spec["atoms"]) + dxy = arrows_spec["data"][arrows_spec["atoms"]] + dxy = np.tile(np.ravel(dxy), reps).reshape(-1, arrows_spec["data"].shape[-1]) # If it is a 1D plot, make sure that the arrows have two coordinates, being 0 the second one. if dxy.shape[-1] == 1: dxy = np.array([dxy[:, 0], np.zeros_like(dxy[:, 0])]).T kwargs = {} - kwargs['line'] = {'color': arrows_spec['color'], 'width': arrows_spec['width'], 'opacity': arrows_spec.get('opacity', 1)} - kwargs['name'] = arrows_spec['name'] - kwargs['arrowhead_scale'] = arrows_spec['arrowhead_scale'] - kwargs['arrowhead_angle'] = arrows_spec['arrowhead_angle'] - kwargs['annotate'] = arrows_spec.get('annotate', False) - kwargs['scale'] = arrows_spec['scale'] + kwargs["line"] = { + "color": arrows_spec["color"], + "width": arrows_spec["width"], + "opacity": arrows_spec.get("opacity", 1), + } + kwargs["name"] = arrows_spec["name"] + kwargs["arrowhead_scale"] = arrows_spec["arrowhead_scale"] + kwargs["arrowhead_angle"] = arrows_spec["arrowhead_angle"] + kwargs["annotate"] = arrows_spec.get("annotate", False) + kwargs["scale"] = arrows_spec["scale"] if dxy.shape[-1] < 3: - action = plot_actions.draw_arrows(x=np.ravel(filtered.x), y=np.ravel(filtered.y), dxy=dxy, **kwargs) + action = plot_actions.draw_arrows( + x=np.ravel(filtered.x), y=np.ravel(filtered.y), dxy=dxy, **kwargs + ) else: - action = plot_actions.draw_arrows_3D(x=np.ravel(filtered.x), y=np.ravel(filtered.y), z=np.ravel(filtered.z), dxyz=dxy, **kwargs) + action = plot_actions.draw_arrows_3D( + x=np.ravel(filtered.x), + y=np.ravel(filtered.y), + z=np.ravel(filtered.z), + dxyz=dxy, + **kwargs, + ) actions.append(action) return actions -def _sanitize_scale(scale: float, ndim: int, ndim_scale: Tuple[float, float, float] = (16, 16, 1)): - return ndim_scale[ndim-1] * scale - -def geometry_plot(geometry: Geometry, - axes: Axes = ["x", "y", "z"], - atoms: AtomsArgument = None, - atoms_style: Sequence[AtomsStyleSpec] = [], - atoms_scale: float = 1., + +def _sanitize_scale( + scale: float, ndim: int, ndim_scale: Tuple[float, float, float] = (16, 16, 1) +): + return ndim_scale[ndim - 1] * scale + + +def geometry_plot( + geometry: Geometry, + axes: Axes = ["x", "y", "z"], + atoms: AtomsArgument = None, + atoms_style: Sequence[AtomsStyleSpec] = [], + atoms_scale: float = 1.0, atoms_colorscale: Optional[str] = None, drawing_mode: Literal["scatter", "balls", None] = None, - bind_bonds_to_ats: bool = True, - points_per_bond: int = 20, - bonds_style: StyleSpec = {}, - bonds_scale: float = 1., - bonds_colorscale: Optional[str] = None, - show_atoms: bool = True, - show_bonds: bool = True, - show_cell: Literal["box", "axes", False] = "box", - cell_style: StyleSpec = {}, + bind_bonds_to_ats: bool = True, + points_per_bond: int = 20, + bonds_style: StyleSpec = {}, + bonds_scale: float = 1.0, + bonds_colorscale: Optional[str] = None, + show_atoms: bool = True, + show_bonds: bool = True, + show_cell: Literal["box", "axes", False] = "box", + cell_style: StyleSpec = {}, nsc: Tuple[int, int, int] = (1, 1, 1), atoms_ndim_scale: Tuple[float, float, float] = (16, 16, 1), bonds_ndim_scale: Tuple[float, float, float] = (1, 1, 10), - dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, - arrows: Sequence[AtomArrowSpec] = (), + dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, + arrows: Sequence[AtomArrowSpec] = (), backend="plotly", ) -> Figure: """Plots a geometry structure, with plentiful of customization options. - + Parameters ---------- geometry: @@ -168,89 +184,123 @@ def geometry_plot(geometry: Geometry, filtered_atoms = select(atoms_dataset, "atom", atoms_filter) tiled_atoms = tile_data_sc(filtered_atoms, nsc=nsc) sc_atoms = stack_sc_data(tiled_atoms, newname="sc_atom", dims=["atom"]) - projected_atoms = project_to_axes(sc_atoms, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d) + projected_atoms = project_to_axes( + sc_atoms, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d + ) atoms_scale = _sanitize_scale(atoms_scale, ndim, atoms_ndim_scale) final_atoms = scale_variable(projected_atoms, "size", scale=atoms_scale) atom_mode = _get_atom_mode(drawing_mode, ndim) atom_plottings = draw_xarray_xy( - data=final_atoms, x="x", y="y", z=z, width="size", what=atom_mode, colorscale=atoms_colorscale, - set_axequal=True, name="Atoms" + data=final_atoms, + x="x", + y="y", + z=z, + width="size", + what=atom_mode, + colorscale=atoms_colorscale, + set_axequal=True, + name="Atoms", ) - + # Here we start to process bonds bonds = find_all_bonds(geometry) show_bonds = matches(ndim, 1, False, show_bonds) styled_bonds = style_bonds(bonds, bonds_style) bonds_dataset = add_xyz_to_bonds_dataset(styled_bonds) - bonds_filter = sanitize_bonds_selection(bonds_dataset, sanitized_atoms, bind_bonds_to_ats, show_bonds) + bonds_filter = sanitize_bonds_selection( + bonds_dataset, sanitized_atoms, bind_bonds_to_ats, show_bonds + ) filtered_bonds = select(bonds_dataset, "bond_index", bonds_filter) tiled_bonds = tile_data_sc(filtered_bonds, nsc=nsc) - + projected_bonds = project_to_axes(tiled_bonds, axes=axes) bond_lines = bonds_to_lines(projected_bonds, points_per_bond=points_per_bond) bonds_scale = _sanitize_scale(bonds_scale, ndim, bonds_ndim_scale) final_bonds = scale_variable(bond_lines, "width", scale=bonds_scale) - bond_plottings = draw_xarray_xy(data=final_bonds, x="x", y="y", z=z, set_axequal=True, name="Bonds", colorscale=bonds_colorscale) - + bond_plottings = draw_xarray_xy( + data=final_bonds, + x="x", + y="y", + z=z, + set_axequal=True, + name="Bonds", + colorscale=bonds_colorscale, + ) + # And now the cell show_cell = matches(ndim, 1, False, show_cell) cell_plottings = cell_plot_actions( - cell=geometry, show_cell=show_cell, cell_style=cell_style, - axes=axes, dataaxis_1d=dataaxis_1d + cell=geometry, + show_cell=show_cell, + cell_style=cell_style, + axes=axes, + dataaxis_1d=dataaxis_1d, ) - + # And the arrows - arrow_data = sanitize_arrows(geometry, arrows, atoms=sanitized_atoms, ndim=ndim, axes=axes) + arrow_data = sanitize_arrows( + geometry, arrows, atoms=sanitized_atoms, ndim=ndim, axes=axes + ) arrow_plottings = _get_arrow_plottings(projected_atoms, arrow_data, nsc=nsc) - all_actions = plot_actions.combined(bond_plottings, atom_plottings, cell_plottings, arrow_plottings, composite_method=None) - + all_actions = plot_actions.combined( + bond_plottings, + atom_plottings, + cell_plottings, + arrow_plottings, + composite_method=None, + ) + return get_figure(backend=backend, plot_actions=all_actions) -class GeometryPlot(Plot): +class GeometryPlot(Plot): function = staticmethod(geometry_plot) @property def geometry(self): - return self.nodes.inputs['geometry']._output - + return self.nodes.inputs["geometry"]._output + + _T = TypeVar("_T", list, tuple, dict) - -def _sites_specs_to_atoms_specs(sites_specs: _T) -> _T: + +def _sites_specs_to_atoms_specs(sites_specs: _T) -> _T: if isinstance(sites_specs, dict): if "sites" in sites_specs: sites_specs = sites_specs.copy() - sites_specs['atoms'] = sites_specs.pop('sites') + sites_specs["atoms"] = sites_specs.pop("sites") return sites_specs else: - return type(sites_specs)(_sites_specs_to_atoms_specs(style_spec) for style_spec in sites_specs) - + return type(sites_specs)( + _sites_specs_to_atoms_specs(style_spec) for style_spec in sites_specs + ) + + def sites_plot( - sites_obj: BrillouinZone, - axes: Axes = ["x", "y", "z"], - sites: AtomsArgument = None, - sites_style: Sequence[AtomsStyleSpec] = [], - sites_scale: float = 1., + sites_obj: BrillouinZone, + axes: Axes = ["x", "y", "z"], + sites: AtomsArgument = None, + sites_style: Sequence[AtomsStyleSpec] = [], + sites_scale: float = 1.0, sites_name: str = "Sites", sites_colorscale: Optional[str] = None, drawing_mode: Literal["scatter", "balls", "line", None] = None, - show_cell: Literal["box", "axes", False] = False, - cell_style: StyleSpec = {}, + show_cell: Literal["box", "axes", False] = False, + cell_style: StyleSpec = {}, nsc: Tuple[int, int, int] = (1, 1, 1), sites_ndim_scale: Tuple[float, float, float] = (1, 1, 1), - dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, - arrows: Sequence[AtomArrowSpec] = (), + dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None, + arrows: Sequence[AtomArrowSpec] = (), backend="plotly", ) -> Figure: """Plots sites from an object that can be parsed into a geometry. The only differences between this plot and a geometry plot is the naming of the inputs and the fact that there are no options to plot bonds. - + Parameters ---------- sites_obj: @@ -302,32 +352,52 @@ def sites_plot( tiled_sites = tile_data_sc(filtered_sites, nsc=nsc) sc_sites = stack_sc_data(tiled_sites, newname="sc_atom", dims=["atom"]) sites_units = get_sites_units(sites_obj) - projected_sites = project_to_axes(sc_sites, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d, cartesian_units=sites_units) + projected_sites = project_to_axes( + sc_sites, + axes=axes, + sort_by_depth=True, + dataaxis_1d=dataaxis_1d, + cartesian_units=sites_units, + ) sites_scale = _sanitize_scale(sites_scale, ndim, sites_ndim_scale) final_sites = scale_variable(projected_sites, "size", scale=sites_scale) sites_mode = _get_atom_mode(drawing_mode, ndim) site_plottings = draw_xarray_xy( - data=final_sites, x="x", y="y", z=z, width="size", what=sites_mode, colorscale=sites_colorscale, - set_axequal=True, name=sites_name, + data=final_sites, + x="x", + y="y", + z=z, + width="size", + what=sites_mode, + colorscale=sites_colorscale, + set_axequal=True, + name=sites_name, ) - + # And now the cell show_cell = matches(ndim, 1, False, show_cell) cell_plottings = cell_plot_actions( - cell=fake_geometry, show_cell=show_cell, cell_style=cell_style, - axes=axes, dataaxis_1d=dataaxis_1d + cell=fake_geometry, + show_cell=show_cell, + cell_style=cell_style, + axes=axes, + dataaxis_1d=dataaxis_1d, ) - + # And the arrows atom_arrows = _sites_specs_to_atoms_specs(arrows) - arrow_data = sanitize_arrows(fake_geometry, atom_arrows, atoms=sanitized_sites, ndim=ndim, axes=axes) + arrow_data = sanitize_arrows( + fake_geometry, atom_arrows, atoms=sanitized_sites, ndim=ndim, axes=axes + ) arrow_plottings = _get_arrow_plottings(projected_sites, arrow_data, nsc=nsc) - all_actions = plot_actions.combined(site_plottings, cell_plottings, arrow_plottings, composite_method=None) - + all_actions = plot_actions.combined( + site_plottings, cell_plottings, arrow_plottings, composite_method=None + ) + return get_figure(backend=backend, plot_actions=all_actions) -class SitesPlot(Plot): +class SitesPlot(Plot): function = staticmethod(sites_plot) diff --git a/src/sisl/viz/plots/grid.py b/src/sisl/viz/plots/grid.py index c522eb65f2..d89d36abe4 100644 --- a/src/sisl/viz/plots/grid.py +++ b/src/sisl/viz/plots/grid.py @@ -33,40 +33,52 @@ from .geometry import geometry_plot -def _get_structure_plottings(plot_geom, geometry, axes, nsc, geom_kwargs={},): +def _get_structure_plottings( + plot_geom, + geometry, + axes, + nsc, + geom_kwargs={}, +): if plot_geom: - geom_kwargs = ChainMap(geom_kwargs, {"axes": axes, "geometry": geometry, "nsc": nsc, "show_cell": False}) + geom_kwargs = ChainMap( + geom_kwargs, + {"axes": axes, "geometry": geometry, "nsc": nsc, "show_cell": False}, + ) plot_actions = geometry_plot(**geom_kwargs).plot_actions else: plot_actions = [] return plot_actions + def grid_plot( - grid: Optional[Grid] = None, - axes: Axes = ["z"], - represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", + grid: Optional[Grid] = None, + axes: Axes = ["z"], + represent: Literal[ + "real", "imag", "mod", "phase", "deg_phase", "rad_phase" + ] = "real", transforms: Sequence[Union[str, Callable]] = (), - reduce_method: Literal["average", "sum"] = "average", + reduce_method: Literal["average", "sum"] = "average", boundary_mode: str = "grid-wrap", - nsc: Tuple[int, int, int] = (1, 1, 1), - interp: Tuple[int, int, int] = (1, 1, 1), - isos: Sequence[dict] = [], + nsc: Tuple[int, int, int] = (1, 1, 1), + interp: Tuple[int, int, int] = (1, 1, 1), + isos: Sequence[dict] = [], smooth: bool = False, - colorscale: Optional[str] = None, - crange: Optional[Tuple[float, float]] = None, + colorscale: Optional[str] = None, + crange: Optional[Tuple[float, float]] = None, cmid: Optional[float] = None, - show_cell: Literal["box", "axes", False] = "box", + show_cell: Literal["box", "axes", False] = "box", cell_style: dict = {}, x_range: Optional[Sequence[float]] = None, - y_range: Optional[Sequence[float]] = None, + y_range: Optional[Sequence[float]] = None, z_range: Optional[Sequence[float]] = None, - plot_geom: bool = False, - geom_kwargs: dict = {}, - backend: str = "plotly" + plot_geom: bool = False, + geom_kwargs: dict = {}, + backend: str = "plotly", ) -> Figure: """Plots a grid, with plentiful of customization options. - + Parameters ---------- grid: @@ -103,7 +115,7 @@ def grid_plot( cell_style: Style specification for the cell. See the showcase notebooks for examples. x_range: - The range of the x axis to take into account. + The range of the x axis to take into account. Even if the X axis is not displayed! This is important because the reducing operation will only be applied on this range. y_range: @@ -120,12 +132,11 @@ def grid_plot( Keyword arguments to pass to the geometry plot of the associated geometry. backend: The backend to use to generate the figure. - + See also ---------- scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed. """ - axes = sanitize_axes(axes) @@ -136,64 +147,88 @@ def grid_plot( tiled_grid = tile_grid(grid_repr, nsc=nsc) ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode) - + grid_axes = get_grid_axes(ort_grid, axes=axes) transformed_grid = apply_transforms(ort_grid, transforms) - subbed_grid = sub_grid(transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range) + subbed_grid = sub_grid( + transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range + ) - reduced_grid = reduce_grid(subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes) + reduced_grid = reduce_grid( + subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes + ) interp_grid = interpolate_grid(reduced_grid, interp=interp) # Finally, here comes the plotting! grid_ds = grid_to_dataarray(interp_grid, axes=axes, grid_axes=grid_axes, nsc=nsc) - grid_plottings = draw_grid(data=grid_ds, isos=isos, colorscale=colorscale, crange=crange, cmid=cmid, smooth=smooth) + grid_plottings = draw_grid( + data=grid_ds, + isos=isos, + colorscale=colorscale, + crange=crange, + cmid=cmid, + smooth=smooth, + ) # Process the cell as well cell_plottings = cell_plot_actions( - cell=grid, show_cell=show_cell, cell_style=cell_style, + cell=grid, + show_cell=show_cell, + cell_style=cell_style, axes=axes, ) # And maybe plot the strucuture - geom_plottings = _get_structure_plottings(plot_geom=plot_geom, geometry=geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=nsc) + geom_plottings = _get_structure_plottings( + plot_geom=plot_geom, + geometry=geometry, + geom_kwargs=geom_kwargs, + axes=axes, + nsc=nsc, + ) - all_plottings = combined(grid_plottings, cell_plottings, geom_plottings, composite_method=None) + all_plottings = combined( + grid_plottings, cell_plottings, geom_plottings, composite_method=None + ) return get_figure(backend=backend, plot_actions=all_plottings) + def wavefunction_plot( - eigenstate: EigenstateData, - i: int = 0, - geometry: Optional[Geometry] = None, - grid_prec: float = 0.2, - # All grid inputs. - grid: Optional[Grid] = None, - axes: Axes = ["z"], - represent: Literal["real", "imag", "mod", "phase", "deg_phase", "rad_phase"] = "real", + eigenstate: EigenstateData, + i: int = 0, + geometry: Optional[Geometry] = None, + grid_prec: float = 0.2, + # All grid inputs. + grid: Optional[Grid] = None, + axes: Axes = ["z"], + represent: Literal[ + "real", "imag", "mod", "phase", "deg_phase", "rad_phase" + ] = "real", transforms: Sequence[Union[str, Callable]] = (), - reduce_method: Literal["average", "sum"] = "average", + reduce_method: Literal["average", "sum"] = "average", boundary_mode: str = "grid-wrap", - nsc: Tuple[int, int, int] = (1, 1, 1), - interp: Tuple[int, int, int] = (1, 1, 1), - isos: Sequence[dict] = [], + nsc: Tuple[int, int, int] = (1, 1, 1), + interp: Tuple[int, int, int] = (1, 1, 1), + isos: Sequence[dict] = [], smooth: bool = False, - colorscale: Optional[str] = None, - crange: Optional[Tuple[float, float]] = None, + colorscale: Optional[str] = None, + crange: Optional[Tuple[float, float]] = None, cmid: Optional[float] = None, - show_cell: Literal["box", "axes", False] = "box", + show_cell: Literal["box", "axes", False] = "box", cell_style: dict = {}, x_range: Optional[Sequence[float]] = None, - y_range: Optional[Sequence[float]] = None, + y_range: Optional[Sequence[float]] = None, z_range: Optional[Sequence[float]] = None, - plot_geom: bool = False, - geom_kwargs: dict = {}, - backend: str = "plotly" + plot_geom: bool = False, + geom_kwargs: dict = {}, + backend: str = "plotly", ) -> Figure: """Plots a wavefunction in real space. - + Parameters ---------- eigenstate: @@ -201,7 +236,7 @@ def wavefunction_plot( i: The index of the eigenstate to plot. geometry: - Geometry to use to project the eigenstate to real space. + Geometry to use to project the eigenstate to real space. If None, the geometry associated with the eigenstate is used. grid_prec: The precision of the grid where the wavefunction is projected. @@ -239,7 +274,7 @@ def wavefunction_plot( cell_style: Style specification for the cell. See the showcase notebooks for examples. x_range: - The range of the x axis to take into account. + The range of the x axis to take into account. Even if the X axis is not displayed! This is important because the reducing operation will only be applied on this range. y_range: @@ -256,19 +291,21 @@ def wavefunction_plot( Keyword arguments to pass to the geometry plot of the associated geometry. backend: The backend to use to generate the figure. - + See also ---------- scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed. """ - + # Create a grid with the wavefunction in it. i_eigenstate = get_eigenstate(eigenstate, i) geometry = eigenstate_geometry(eigenstate, geometry=geometry) tiled_geometry = tile_if_k(geometry=geometry, nsc=nsc, eigenstate=i_eigenstate) grid_nsc = get_grid_nsc(nsc=nsc, eigenstate=i_eigenstate) - grid = project_wavefunction(eigenstate=i_eigenstate, grid_prec=grid_prec, grid=grid, geometry=tiled_geometry) + grid = project_wavefunction( + eigenstate=i_eigenstate, grid_prec=grid_prec, grid=grid, geometry=tiled_geometry + ) # Grid processing axes = sanitize_axes(axes) @@ -278,42 +315,66 @@ def wavefunction_plot( tiled_grid = tile_grid(grid_repr, nsc=grid_nsc) ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode) - + grid_axes = get_grid_axes(ort_grid, axes=axes) transformed_grid = apply_transforms(ort_grid, transforms) - subbed_grid = sub_grid(transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range) + subbed_grid = sub_grid( + transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range + ) - reduced_grid = reduce_grid(subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes) + reduced_grid = reduce_grid( + subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes + ) interp_grid = interpolate_grid(reduced_grid, interp=interp) # Finally, here comes the plotting! - grid_ds = grid_to_dataarray(interp_grid, axes=axes, grid_axes=grid_axes, nsc=grid_nsc) - grid_plottings = draw_grid(data=grid_ds, isos=isos, colorscale=colorscale, crange=crange, cmid=cmid, smooth=smooth) + grid_ds = grid_to_dataarray( + interp_grid, axes=axes, grid_axes=grid_axes, nsc=grid_nsc + ) + grid_plottings = draw_grid( + data=grid_ds, + isos=isos, + colorscale=colorscale, + crange=crange, + cmid=cmid, + smooth=smooth, + ) # Process the cell as well cell_plottings = cell_plot_actions( - cell=grid, show_cell=show_cell, cell_style=cell_style, + cell=grid, + show_cell=show_cell, + cell_style=cell_style, axes=axes, ) # And maybe plot the strucuture - geom_plottings = _get_structure_plottings(plot_geom=plot_geom, geometry=tiled_geometry, geom_kwargs=geom_kwargs, axes=axes, nsc=grid_nsc) + geom_plottings = _get_structure_plottings( + plot_geom=plot_geom, + geometry=tiled_geometry, + geom_kwargs=geom_kwargs, + axes=axes, + nsc=grid_nsc, + ) - all_plottings = combined(grid_plottings, cell_plottings, geom_plottings, composite_method=None) + all_plottings = combined( + grid_plottings, cell_plottings, geom_plottings, composite_method=None + ) return get_figure(backend=backend, plot_actions=all_plottings) -class GridPlot(Plot): +class GridPlot(Plot): function = staticmethod(grid_plot) -class WavefunctionPlot(GridPlot): +class WavefunctionPlot(GridPlot): function = staticmethod(wavefunction_plot) + # The following commented code is from the old viz module, where the GridPlot had a scan method. # It looks very nice, but probably should be reimplemented as a standalone function that plots a grid slice, # and then merge those grid slices to create a scan. diff --git a/src/sisl/viz/plots/merged.py b/src/sisl/viz/plots/merged.py index 0cc1fbf652..f13285e493 100644 --- a/src/sisl/viz/plots/merged.py +++ b/src/sisl/viz/plots/merged.py @@ -5,10 +5,13 @@ from ..plotters.plot_actions import combined -def merge_plots(*figures: Figure, - composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = "multiple", +def merge_plots( + *figures: Figure, + composite_method: Optional[ + Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"] + ] = "multiple", backend: Literal["plotly", "matplotlib", "py3dmol", "blender"] = "plotly", - **kwargs + **kwargs, ) -> Figure: """Combines multiple plots into a single figure. @@ -30,7 +33,7 @@ def merge_plots(*figures: Figure, plot_actions = combined( *[fig.plot_actions for fig in figures], composite_method=composite_method, - **kwargs + **kwargs, ) return get_figure(plot_actions=plot_actions, backend=backend) diff --git a/src/sisl/viz/plots/orbital_groups_plot.py b/src/sisl/viz/plots/orbital_groups_plot.py index 4f26884409..eb1c73b86a 100644 --- a/src/sisl/viz/plots/orbital_groups_plot.py +++ b/src/sisl/viz/plots/orbital_groups_plot.py @@ -3,7 +3,7 @@ class OrbitalGroupsPlot(Plot): """Contains methods to manipulate an input accepting groups of orbitals. - + Plots that need this functionality should inherit from this class. """ @@ -19,7 +19,7 @@ def _matches_group(self, group, query, iReq=None): return True return ("name" in group and group.get("name") in query) or iReq in query - + def groups(self, *i_or_names): """Gets the groups that match your query @@ -34,11 +34,15 @@ def groups(self, *i_or_names): If no query is provided, all the groups will be matched """ - return [req for i, req in enumerate(self.get_input(self._orbital_groups_input_key)) if self._matches_group(req, i_or_names, i)] + return [ + req + for i, req in enumerate(self.get_input(self._orbital_groups_input_key)) + if self._matches_group(req, i_or_names, i) + ] + + def add_group(self, group={}, clean=False, **kwargs): + """Adds a new orbitals group. - def add_group(self, group = {}, clean=False, **kwargs): - """Adds a new orbitals group. - The new group can be passed as a dict or as keyword arguments. The keyword arguments will overwrite what has been passed as a dict if there is conflict. @@ -55,7 +59,11 @@ def add_group(self, group = {}, clean=False, **kwargs): """ group = {**group, **kwargs} - groups = [group] if clean else [*self.get_input(self._orbital_groups_input_key), group] + groups = ( + [group] + if clean + else [*self.get_input(self._orbital_groups_input_key), group] + ) return self.update_inputs(**{self._orbital_groups_input_key: groups}) def remove_groups(self, *i_or_names, all=False): @@ -75,7 +83,11 @@ def remove_groups(self, *i_or_names, all=False): if all: groups = [] else: - groups = [req for i, req in enumerate(self.get_input(self._orbital_groups_input_key)) if not self._matches_group(req, i_or_names, i)] + groups = [ + req + for i, req in enumerate(self.get_input(self._orbital_groups_input_key)) + if not self._matches_group(req, i_or_names, i) + ] return self.update_inputs(**{self._orbital_groups_input_key: groups}) @@ -104,7 +116,17 @@ def update_groups(self, *i_or_names, **kwargs): return self.update_inputs(**{self._orbital_groups_input_key: groups}) - def split_groups(self, *i_or_names, on="species", only=None, exclude=None, remove=True, clean=False, ignore_constraints=False, **kwargs): + def split_groups( + self, + *i_or_names, + on="species", + only=None, + exclude=None, + remove=True, + clean=False, + ignore_constraints=False, + **kwargs, + ): """Splits the orbital groups into multiple groups. Parameters @@ -134,7 +156,7 @@ def split_groups(self, *i_or_names, on="species", only=None, exclude=None, remov If False, all the groups that come from the method will be drawn on top of what is already there. ignore_constraints: boolean or array-like, optional - determines whether constraints (imposed by the group to be splitted) + determines whether constraints (imposed by the group to be splitted) on the parameters that we want to split along should be taken into consideration. If `False`: all constraints considered. @@ -172,20 +194,28 @@ def split_groups(self, *i_or_names, on="species", only=None, exclude=None, remov groups = [] for req in reqs: new_groups = queries_manager._split_query( - req, on=on, only=only, exclude=exclude, - ignore_constraints=ignore_constraints, **kwargs + req, + on=on, + only=only, + exclude=exclude, + ignore_constraints=ignore_constraints, + **kwargs, ) groups.extend(new_groups) if remove: - old_groups = [req for i, req in enumerate(old_groups) if not self._matches_group(req, i_or_names, i)] + old_groups = [ + req + for i, req in enumerate(old_groups) + if not self._matches_group(req, i_or_names, i) + ] if not clean: groups = [*old_groups, *groups] return self.update_inputs(**{self._orbital_groups_input_key: groups}) - + def split_orbs(self, on="species", only=None, exclude=None, clean=True, **kwargs): """ Splits the orbitals into different groups. @@ -213,4 +243,6 @@ def split_orbs(self, on="species", only=None, exclude=None, clean=True, **kwargs will split on the different orbitals but will take only those that belong to carbon atoms. """ - return self.split_groups(on=on, only=only, exclude=exclude, clean=clean, **kwargs) + return self.split_groups( + on=on, only=only, exclude=exclude, clean=clean, **kwargs + ) diff --git a/src/sisl/viz/plots/pdos.py b/src/sisl/viz/plots/pdos.py index 2d6a8b90d4..f9636b9c60 100644 --- a/src/sisl/viz/plots/pdos.py +++ b/src/sisl/viz/plots/pdos.py @@ -19,11 +19,11 @@ def pdos_plot( pdos_data: PDOSData, - groups: Sequence[OrbitalStyleQuery]=[{"name": "DOS"}], + groups: Sequence[OrbitalStyleQuery] = [{"name": "DOS"}], Erange: Tuple[float, float] = (-2, 2), - E_axis: Literal["x", "y"] = "x", + E_axis: Literal["x", "y"] = "x", line_mode: Literal["line", "scatter", "area_line"] = "line", - line_scale: float = 1., + line_scale: float = 1.0, backend: str = "plotly", ) -> Figure: """Plot the projected density of states. @@ -53,8 +53,14 @@ def pdos_plot( orbital_manager = get_orbital_queries_manager(pdos_data) groups_data = reduce_orbital_data( - E_PDOS, groups=groups, orb_dim="orb", spin_dim="spin", sanitize_group=orbital_manager, - group_vars=('color', 'size', 'dash'), groups_dim="group", drop_empty=True, + E_PDOS, + groups=groups, + orb_dim="orb", + spin_dim="spin", + sanitize_group=orbital_manager, + group_vars=("color", "size", "dash"), + groups_dim="group", + drop_empty=True, spin_reduce=np.sum, ) @@ -67,13 +73,22 @@ def pdos_plot( # A PlotterNode gets the processed data and creates abstract actions (backend agnostic) # that should be performed on the figure. The output of this node # must be fed to a figure (backend specific). - final_groups_data = scale_variable(groups_data, var="size", scale=line_scale, default_value=1) - plot_actions = draw_xarray_xy(data=final_groups_data, x=x, y=y, width="size", what=line_mode, dependent_axis=dependent_axis) + final_groups_data = scale_variable( + groups_data, var="size", scale=line_scale, default_value=1 + ) + plot_actions = draw_xarray_xy( + data=final_groups_data, + x=x, + y=y, + width="size", + what=line_mode, + dependent_axis=dependent_axis, + ) return get_figure(backend=backend, plot_actions=plot_actions) -class PdosPlot(OrbitalGroupsPlot): +class PdosPlot(OrbitalGroupsPlot): function = staticmethod(pdos_plot) def split_DOS(self, on="species", only=None, exclude=None, clean=True, **kwargs): @@ -113,5 +128,6 @@ def split_DOS(self, on="species", only=None, exclude=None, clean=True, **kwargs) >>> # be replaced by the value of n. >>> plot.split_DOS(on="n+l", species=["Au"], name="Au $ns") """ - return self.split_groups(on=on, only=only, exclude=exclude, clean=clean, **kwargs) - + return self.split_groups( + on=on, only=only, exclude=exclude, clean=clean, **kwargs + ) diff --git a/src/sisl/viz/plots/tests/test_bands.py b/src/sisl/viz/plots/tests/test_bands.py index 6fb87c3862..5e07d2f06b 100644 --- a/src/sisl/viz/plots/tests/test_bands.py +++ b/src/sisl/viz/plots/tests/test_bands.py @@ -12,17 +12,23 @@ def backend(request): pytest.importorskip(request.param) return request.param -@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"]) + +@pytest.fixture( + scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"] +) def spin(request): return Spin(request.param) + @pytest.fixture(scope="module") def gap(): return 2.5 + @pytest.fixture(scope="module") def bands_data(spin, gap): return BandsData.toy_example(spin=spin, gap=gap) + def test_bands_plot(bands_data, backend): bands_plot(bands_data, backend=backend) diff --git a/src/sisl/viz/plots/tests/test_geometry.py b/src/sisl/viz/plots/tests/test_geometry.py index c5096da93c..0a2b09d3e1 100644 --- a/src/sisl/viz/plots/tests/test_geometry.py +++ b/src/sisl/viz/plots/tests/test_geometry.py @@ -12,16 +12,18 @@ def backend(request): pytest.importorskip(request.param) return request.param + @pytest.fixture(scope="module", params=["x", "xy", "xyz"]) def axes(request): return request.param + @pytest.fixture(scope="module") def geometry(): return sisl.geom.graphene() -def test_geometry_plot(geometry, axes, backend): +def test_geometry_plot(geometry, axes, backend): if axes == "xyz" and backend == "matplotlib": with pytest.raises(NotImplementedError): geometry_plot(geometry, axes=axes, backend=backend) diff --git a/src/sisl/viz/plots/tests/test_grid.py b/src/sisl/viz/plots/tests/test_grid.py index 9fb03631aa..0d54d94cb0 100644 --- a/src/sisl/viz/plots/tests/test_grid.py +++ b/src/sisl/viz/plots/tests/test_grid.py @@ -13,10 +13,12 @@ def backend(request): pytest.importorskip(request.param) return request.param + @pytest.fixture(scope="module", params=["x", "xy", "xyz"]) def axes(request): return request.param + @pytest.fixture(scope="module") def grid(): geometry = sisl.geom.graphene() @@ -25,8 +27,8 @@ def grid(): grid.grid[:] = np.linspace(0, 1000, 1000).reshape(10, 10, 10) return grid -def test_grid_plot(grid, axes, backend): +def test_grid_plot(grid, axes, backend): if axes == "xyz" and backend == "matplotlib": with pytest.raises(NotImplementedError): grid_plot(grid, axes=axes, backend=backend) diff --git a/src/sisl/viz/plots/tests/test_pdos.py b/src/sisl/viz/plots/tests/test_pdos.py index 2b9298b004..5f3083422c 100644 --- a/src/sisl/viz/plots/tests/test_pdos.py +++ b/src/sisl/viz/plots/tests/test_pdos.py @@ -12,13 +12,18 @@ def backend(request): pytest.importorskip(request.param) return request.param -@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"]) + +@pytest.fixture( + scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"] +) def spin(request): return Spin(request.param) + @pytest.fixture(scope="module") def pdos_data(spin): return PDOSData.toy_example(spin=spin) + def test_pdos_plot(pdos_data, backend): pdos_plot(pdos_data, backend=backend) diff --git a/src/sisl/viz/plotters/__init__.py b/src/sisl/viz/plotters/__init__.py index 9a5ac8cc0b..b5fc138916 100644 --- a/src/sisl/viz/plotters/__init__.py +++ b/src/sisl/viz/plotters/__init__.py @@ -1,3 +1,3 @@ """Functions that generate plot actions to be passed to figures.""" -from . import plot_actions \ No newline at end of file +from . import plot_actions diff --git a/src/sisl/viz/plotters/cell.py b/src/sisl/viz/plotters/cell.py index 0ed597f7e0..a98493df97 100644 --- a/src/sisl/viz/plotters/cell.py +++ b/src/sisl/viz/plotters/cell.py @@ -11,6 +11,7 @@ def get_ndim(axes: Axes) -> int: return len(axes) + def get_z(ndim: int) -> Literal["z", False]: if ndim == 3: z = "z" @@ -18,17 +19,33 @@ def get_z(ndim: int) -> Literal["z", False]: z = False return z -def cell_plot_actions(cell: CellLike = None, show_cell: Literal[False, "box", "axes"] = "box", axes=["x", "y", "z"], - name: str = "Unit cell", cell_style={}, dataaxis_1d=None): + +def cell_plot_actions( + cell: CellLike = None, + show_cell: Literal[False, "box", "axes"] = "box", + axes=["x", "y", "z"], + name: str = "Unit cell", + cell_style={}, + dataaxis_1d=None, +): if show_cell == False: cell_plottings = [] else: cell_ds = gen_cell_dataset(cell) cell_lines = cell_to_lines(cell_ds, show_cell, cell_style) - projected_cell_lines = project_to_axes(cell_lines, axes=axes, dataaxis_1d=dataaxis_1d) - + projected_cell_lines = project_to_axes( + cell_lines, axes=axes, dataaxis_1d=dataaxis_1d + ) + ndim = get_ndim(axes) z = get_z(ndim) - cell_plottings = draw_xarray_xy(data=projected_cell_lines, x="x", y="y", z=z, set_axequal=ndim > 1, name=name) - - return cell_plottings \ No newline at end of file + cell_plottings = draw_xarray_xy( + data=projected_cell_lines, + x="x", + y="y", + z=z, + set_axequal=ndim > 1, + name=name, + ) + + return cell_plottings diff --git a/src/sisl/viz/plotters/grid.py b/src/sisl/viz/plotters/grid.py index 40fd413175..5960c835d5 100644 --- a/src/sisl/viz/plotters/grid.py +++ b/src/sisl/viz/plotters/grid.py @@ -3,27 +3,36 @@ def draw_grid(data, isos=[], colorscale=None, crange=None, cmid=None, smooth=False): - to_plot = [] - + ndim = data.ndim if ndim == 1: - to_plot.append( - plot_actions.draw_line(x=data.x, y=data.values) - ) + to_plot.append(plot_actions.draw_line(x=data.x, y=data.values)) elif ndim == 2: transposed = data.transpose("y", "x") cmin, cmax = crange if crange is not None else (None, None) to_plot.append( - plot_actions.init_coloraxis(name="grid_color", cmin=cmin, cmax=cmax, cmid=cmid, colorscale=colorscale) + plot_actions.init_coloraxis( + name="grid_color", + cmin=cmin, + cmax=cmax, + cmid=cmid, + colorscale=colorscale, + ) ) - to_plot.append( - plot_actions.draw_heatmap(values=transposed.values, x=data.x, y=data.y, name="HEAT", zsmooth="best" if smooth else False, coloraxis="grid_color") + plot_actions.draw_heatmap( + values=transposed.values, + x=data.x, + y=data.y, + name="HEAT", + zsmooth="best" if smooth else False, + coloraxis="grid_color", + ) ) dx = data.x[1] - data.x[0] @@ -31,26 +40,20 @@ def draw_grid(data, isos=[], colorscale=None, crange=None, cmid=None, smooth=Fal iso_lines = get_isos(transposed, isos) for iso_line in iso_lines: - iso_line['line'] = { + iso_line["line"] = { "color": iso_line.pop("color", None), "opacity": iso_line.pop("opacity", None), "width": iso_line.pop("width", None), - **iso_line.get("line", {}) + **iso_line.get("line", {}), } - to_plot.append( - plot_actions.draw_line(**iso_line) - ) + to_plot.append(plot_actions.draw_line(**iso_line)) elif ndim == 3: isosurfaces = get_isos(data, isos) - + for isosurface in isosurfaces: - to_plot.append( - plot_actions.draw_mesh_3D(**isosurface) - ) - + to_plot.append(plot_actions.draw_mesh_3D(**isosurface)) + if ndim > 1: - to_plot.append( - plot_actions.set_axes_equal() - ) + to_plot.append(plot_actions.set_axes_equal()) - return to_plot \ No newline at end of file + return to_plot diff --git a/src/sisl/viz/plotters/plot_actions.py b/src/sisl/viz/plotters/plot_actions.py index 69f224f60a..13d5f279d9 100644 --- a/src/sisl/viz/plotters/plot_actions.py +++ b/src/sisl/viz/plotters/plot_actions.py @@ -9,14 +9,15 @@ def _register_actions(figure_cls: Type[Figure]): - # Take all actions possible from the Figure class module = sys.modules[__name__] - actions = inspect.getmembers(figure_cls, predicate=lambda x: inspect.isfunction(x) and not x.__name__.startswith("_")) + actions = inspect.getmembers( + figure_cls, + predicate=lambda x: inspect.isfunction(x) and not x.__name__.startswith("_"), + ) for name, function in actions: - sig = inspect.signature(function) @functools.wraps(function) @@ -25,15 +26,20 @@ def a(*args, __method_name__=function.__name__, **kwargs): a.__signature__ = sig.replace(parameters=list(sig.parameters.values())[1:]) a.__module__ = module - + setattr(module, name, a) + _register_actions(Figure) -def combined(*plotters, - composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = None, - provided_list: bool = False, - **kwargs + +def combined( + *plotters, + composite_method: Optional[ + Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"] + ] = None, + provided_list: bool = False, + **kwargs, ): if provided_list: plotters = plotters[0] @@ -41,5 +47,5 @@ def combined(*plotters, return { "composite_method": composite_method, "plot_actions": plotters, - "init_kwargs": kwargs - } \ No newline at end of file + "init_kwargs": kwargs, + } diff --git a/src/sisl/viz/plotters/tests/test_xarray.py b/src/sisl/viz/plotters/tests/test_xarray.py index bbc74ef648..3e9d9e9459 100644 --- a/src/sisl/viz/plotters/tests/test_xarray.py +++ b/src/sisl/viz/plotters/tests/test_xarray.py @@ -7,17 +7,16 @@ def test_empty_dataset(): - - ds = xr.Dataset({"x": ("dim", []), "y": ("dim", [])}) + ds = xr.Dataset({"x": ("dim", []), "y": ("dim", [])}) drawings = draw_xarray_xy(ds, x="x", y="y") assert isinstance(drawings, list) assert len(drawings) == 0 -def test_empty_dataarray(): - arr = xr.DataArray([], name="values", dims=['x']) +def test_empty_dataarray(): + arr = xr.DataArray([], name="values", dims=["x"]) drawings = draw_xarray_xy(arr, x="x") diff --git a/src/sisl/viz/plotters/xarray.py b/src/sisl/viz/plotters/xarray.py index c5922eba4f..167e6b3313 100644 --- a/src/sisl/viz/plotters/xarray.py +++ b/src/sisl/viz/plotters/xarray.py @@ -7,25 +7,30 @@ import sisl.viz.plotters.plot_actions as plot_actions from sisl.messages import info -#from sisl.viz.nodes.processors.grid import get_isos +# from sisl.viz.nodes.processors.grid import get_isos + def _process_xarray_data(data, x=None, y=None, z=False, style={}): axes = {"x": x, "y": y} if z is not False: axes["z"] = z - + ndim = len(axes) # Normalize data to a Dataset if isinstance(data, DataArray): if np.all([ax is None for ax in axes.values()]): - raise ValueError("You have to provide either x or y (or z if it is not False) (one needs to be the fixed variable).") + raise ValueError( + "You have to provide either x or y (or z if it is not False) (one needs to be the fixed variable)." + ) axes = {k: v or data.name for k, v in axes.items()} data = data.to_dataset(name=data.name) else: if np.any([ax is None for ax in axes.values()]): - raise ValueError("Since you provided a Dataset, you have to provide both x and y (and z if it is not False).") - + raise ValueError( + "Since you provided a Dataset, you have to provide both x and y (and z if it is not False)." + ) + data_axis = None fixed_axes = {} # Check, for each axis, if it is uni dimensional (in which case we add it to the fixed axes dictionary) @@ -57,19 +62,21 @@ def _process_xarray_data(data, x=None, y=None, z=False, style={}): for key, value in style.items(): if value in data: style_dims = style_dims.union(set(data[value].dims)) - + extra_style_dims = style_dims - set(data[data_var].dims) if extra_style_dims: - data = data.stack(extra_style_dim=extra_style_dims).transpose('extra_style_dim', ...) + data = data.stack(extra_style_dim=extra_style_dims).transpose( + "extra_style_dim", ... + ) if data[data_var].shape[0] == 0: return None, None, None, None, None - + if len(data[data_var].shape) == 1: data = data.expand_dims(dim={"fake_dim": [0]}, axis=0) # We have to flatten all the dimensions that will not be represented as an axis, # since we will just iterate over them. - dims_to_stack = data[data_var].dims[:-len(last_dims)] + dims_to_stack = data[data_var].dims[: -len(last_dims)] data = data.stack(iterate_dim=dims_to_stack).transpose("iterate_dim", ...) styles = {} @@ -83,35 +90,59 @@ def _process_xarray_data(data, x=None, y=None, z=False, style={}): fixed_coords = {} for ax_key, fixed_axis in fixed_axes.items(): - fixed_coord = data[fixed_axis] + fixed_coord = data[fixed_axis] if "iterate_dim" in fixed_coord.dims: # This is if fixed_coord was a variable of the dataset, which possibly has # gotten the extra iterate_dim added. fixed_coord = fixed_coord.isel(iterate_dim=0) fixed_coords[ax_key] = fixed_coord - #info(f"{self} variables: \n\t- Fixed: {fixed_axes}\n\t- Data axis: {data_axis}\n\t") + # info(f"{self} variables: \n\t- Fixed: {fixed_axes}\n\t- Data axis: {data_axis}\n\t") return plot_data, fixed_coords, styles, data_axis, axes -def draw_xarray_xy(data, x=None, y=None, z=False, color="color", width="width", dash="dash", opacity="opacity", name="", colorscale=None, - what: typing.Literal["line", "scatter", "balls", "area_line", "arrows", "none"] = "line", + +def draw_xarray_xy( + data, + x=None, + y=None, + z=False, + color="color", + width="width", + dash="dash", + opacity="opacity", + name="", + colorscale=None, + what: typing.Literal[ + "line", "scatter", "balls", "area_line", "arrows", "none" + ] = "line", dependent_axis: typing.Optional[typing.Literal["x", "y"]] = None, - set_axrange=False, set_axequal=False + set_axrange=False, + set_axequal=False, ): if what == "none": return [] plot_data, fixed_coords, styles, data_axis, axes = _process_xarray_data( - data, x=x, y=y, z=z, style={"color": color, "width": width, "opacity": opacity, "dash": dash} + data, + x=x, + y=y, + z=z, + style={"color": color, "width": width, "opacity": opacity, "dash": dash}, ) if plot_data is None: return [] to_plot = _draw_xarray_lines( - data=plot_data, style=styles, fixed_coords=fixed_coords, data_axis=data_axis, colorscale=colorscale, what=what, name=name, - dependent_axis=dependent_axis + data=plot_data, + style=styles, + fixed_coords=fixed_coords, + data_axis=data_axis, + colorscale=colorscale, + what=what, + name=name, + dependent_axis=dependent_axis, ) if set_axequal: @@ -132,14 +163,17 @@ def draw_xarray_xy(data, x=None, y=None, z=False, color="color", width="width", if set_axrange: axis["range"] = (float(ax.min()), float(ax.max())) - + axis.update(ax.attrs.get("axis", {})) to_plot.append(plot_actions.set_axis(axis=key, **axis)) return to_plot -def _draw_xarray_lines(data, style, fixed_coords, data_axis, colorscale, what, name="", dependent_axis=None): + +def _draw_xarray_lines( + data, style, fixed_coords, data_axis, colorscale, what, name="", dependent_axis=None +): # Initialize actions list to_plot = [] @@ -150,7 +184,9 @@ def _draw_xarray_lines(data, style, fixed_coords, data_axis, colorscale, what, n lines_style[key] = style.get(key) if lines_style[key] is not None: - extra_style_dims = extra_style_dims or "extra_style_dim" in lines_style[key].dims + extra_style_dims = ( + extra_style_dims or "extra_style_dim" in lines_style[key].dims + ) # If some style is constant, just repeat it. if lines_style[key] is None or "iterate_dim" not in lines_style[key].dims: lines_style[key] = itertools.repeat(lines_style[key]) @@ -159,24 +195,29 @@ def _draw_xarray_lines(data, style, fixed_coords, data_axis, colorscale, what, n # use a special drawing function. If we have to draw lines with multiple widths # we also need to use a special function. line_kwargs = {} - if isinstance(lines_style['color'], itertools.repeat): - color_value = next(lines_style['color']) + if isinstance(lines_style["color"], itertools.repeat): + color_value = next(lines_style["color"]) else: - color_value = lines_style['color'] + color_value = lines_style["color"] - if isinstance(lines_style['width'], itertools.repeat): - width_value = next(lines_style['width']) + if isinstance(lines_style["width"], itertools.repeat): + width_value = next(lines_style["width"]) else: - width_value = lines_style['width'] + width_value = lines_style["width"] if isinstance(color_value, DataArray) and (data.dims[-1] in color_value.dims): color = color_value if color.dtype in (int, float): coloraxis_name = f"{color.name}_{name}" if name else color.name to_plot.append( - plot_actions.init_coloraxis(name=coloraxis_name, cmin=color.values.min(), cmax=color.values.max(), colorscale=colorscale) + plot_actions.init_coloraxis( + name=coloraxis_name, + cmin=color.values.min(), + cmax=color.values.max(), + colorscale=colorscale, + ) ) - line_kwargs = {'coloraxis': coloraxis_name} + line_kwargs = {"coloraxis": coloraxis_name} drawing_function_name = f"draw_multicolor_{what}" elif isinstance(width_value, DataArray) and (data.dims[-1] in width_value.dims): drawing_function_name = f"draw_multisize_{what}" @@ -187,34 +228,39 @@ def _draw_xarray_lines(data, style, fixed_coords, data_axis, colorscale, what, n if len(fixed_coords) == 2: to_plot.append(plot_actions.init_3D()) drawing_function_name += "_3D" - + _drawing_function = getattr(plot_actions, drawing_function_name) if what in ("scatter", "balls"): + def drawing_function(*args, **kwargs): marker = kwargs.pop("line") - marker['size'] = marker.pop("width") + marker["size"] = marker.pop("width") + + to_plot.append(_drawing_function(*args, marker=marker, **kwargs)) - to_plot.append( - _drawing_function(*args, marker=marker, **kwargs) - ) elif what == "area_line": + def drawing_function(*args, **kwargs): to_plot.append( _drawing_function(*args, dependent_axis=dependent_axis, **kwargs) ) + else: + def drawing_function(*args, **kwargs): - to_plot.append( - _drawing_function(*args, **kwargs) - ) + to_plot.append(_drawing_function(*args, **kwargs)) # Define the iterator over lines, containing both values and styles - iterator = zip(data, - lines_style['color'], lines_style['width'], lines_style['opacity'], lines_style['dash'] + iterator = zip( + data, + lines_style["color"], + lines_style["width"], + lines_style["opacity"], + lines_style["dash"], ) fixed_coords_values = {k: arr.values for k, arr in fixed_coords.items()} - + single_line = len(data.iterate_dim) == 1 if name in data.iterate_dim.coords: name_prefix = "" @@ -223,7 +269,6 @@ def drawing_function(*args, **kwargs): # Now just iterate over each line and plot it. for values, *styles in iterator: - names = values.iterate_dim.values[()] if name in values.iterate_dim.coords: line_name = f"{name_prefix}{values.iterate_dim.coords[name].values[()]}" @@ -243,7 +288,12 @@ def drawing_function(*args, **kwargs): parsed_styles.append(style) line_color, line_width, line_opacity, line_dash = parsed_styles - line_style = {"color": line_color, "width": line_width, "opacity": line_opacity, "dash": line_dash} + line_style = { + "color": line_color, + "width": line_width, + "opacity": line_opacity, + "dash": line_dash, + } line = {**line_style, **line_kwargs} coords = { @@ -258,18 +308,27 @@ def drawing_function(*args, **kwargs): if v is None or v.ndim == 0: line_style[k] = itertools.repeat(v) - for l_color, l_width, l_opacity, l_dash in zip(line_style['color'], line_style['width'], line_style['opacity'], line_style['dash']): - line_style = {"color": l_color, "width": l_width, "opacity": l_opacity, "dash": l_dash} + for l_color, l_width, l_opacity, l_dash in zip( + line_style["color"], + line_style["width"], + line_style["opacity"], + line_style["dash"], + ): + line_style = { + "color": l_color, + "width": l_width, + "opacity": l_opacity, + "dash": l_dash, + } drawing_function(**coords, line=line_style, name=line_name) return to_plot - # class PlotterNodeGrid(PlotterXArray): - + # def draw(self, data, isos=[]): - + # ndim = data.ndim # if ndim == 2: @@ -291,9 +350,9 @@ def drawing_function(*args, **kwargs): # self.draw_line(**iso_line) # elif ndim == 3: # isosurfaces = get_isos(data, isos) - + # for isosurface in isosurfaces: # self.draw_mesh_3D(**isosurface) - - -# self.set_axes_equal() \ No newline at end of file + + +# self.set_axes_equal() diff --git a/src/sisl/viz/plotutils.py b/src/sisl/viz/plotutils.py index e66bc3b0d6..7f05955c6d 100644 --- a/src/sisl/viz/plotutils.py +++ b/src/sisl/viz/plotutils.py @@ -10,11 +10,13 @@ try: from pathos.pools import ProcessPool as Pool + pathos_avail = True except Exception: pathos_avail = False try: import tqdm + tqdm_avail = True except Exception: tqdm_avail = False @@ -25,22 +27,34 @@ from sisl.io.sile import get_sile_rules, get_siles from sisl.messages import info -__all__ = ["running_in_notebook", "check_widgets", - "get_plot_classes", "get_plotable_siles", "get_plotable_variables", - "get_session_classes", "get_avail_presets", - "get_nested_key", "modify_nested_dict", "dictOfLists2listOfDicts", - "get_avail_presets", "random_color", - "load", "find_files", "find_plotable_siles", - "shift_trace", "normalize_trace", "swap_trace_axes" +__all__ = [ + "running_in_notebook", + "check_widgets", + "get_plot_classes", + "get_plotable_siles", + "get_plotable_variables", + "get_session_classes", + "get_avail_presets", + "get_nested_key", + "modify_nested_dict", + "dictOfLists2listOfDicts", + "get_avail_presets", + "random_color", + "load", + "find_files", + "find_plotable_siles", + "shift_trace", + "normalize_trace", + "swap_trace_axes", ] -#------------------------------------- +# ------------------------------------- # Ipython -#------------------------------------- +# ------------------------------------- def running_in_notebook(): - """ Finds out whether the code is being run on a notebook. + """Finds out whether the code is being run on a notebook. Returns -------- @@ -48,13 +62,13 @@ def running_in_notebook(): whether the code is running in a notebook """ try: - return get_ipython().__class__.__name__ == 'ZMQInteractiveShell' + return get_ipython().__class__.__name__ == "ZMQInteractiveShell" except NameError: return False def check_widgets(): - """ Checks if some jupyter notebook widgets are there. + """Checks if some jupyter notebook widgets are there. This will be helpful to know how the figures should be displayed. @@ -67,42 +81,48 @@ def check_widgets(): import subprocess widgets = { - 'plotly_avail': False, - 'plotly_error': False, - 'events_avail': False, - 'events_error': False + "plotly_avail": False, + "plotly_error": False, + "events_avail": False, + "events_error": False, } - out, err = subprocess.Popen(['jupyter', 'nbextension', 'list'], stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate() + out, err = subprocess.Popen( + ["jupyter", "nbextension", "list"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate() out = str(out) err = str(err) - if 'plotlywidget' in out: - widgets['plotly_avail'] = True - if 'plotlywidget' in err: - widgets['plotly_error'] = True + if "plotlywidget" in out: + widgets["plotly_avail"] = True + if "plotlywidget" in err: + widgets["plotly_error"] = True - if 'ipyevents' in out: + if "ipyevents" in out: try: import ipyevents - widgets['events_avail'] = True + + widgets["events_avail"] = True except Exception: pass - if 'ipyevents' in err: - widgets['events_error'] = True + if "ipyevents" in err: + widgets["events_error"] = True - widgets['plotly'] = widgets['plotly_avail'] and not widgets['plotly_error'] - widgets['events'] = widgets['events_avail'] and not widgets['events_error'] + widgets["plotly"] = widgets["plotly_avail"] and not widgets["plotly_error"] + widgets["events"] = widgets["events_avail"] and not widgets["events_error"] return widgets -#------------------------------------- + +# ------------------------------------- # Informative -#------------------------------------- +# ------------------------------------- def get_plot_classes(): - """ This method returns all the plot subclasses, even the nested ones. + """This method returns all the plot subclasses, even the nested ones. Returns --------- @@ -112,7 +132,6 @@ def get_plot_classes(): from . import Plot def get_all_subclasses(cls): - all_subclasses = [] for Subclass in cls.__subclasses__(): @@ -122,11 +141,11 @@ def get_all_subclasses(cls): return all_subclasses - return sorted(get_all_subclasses(Plot), key = lambda clss: clss.plot_name()) + return sorted(get_all_subclasses(Plot), key=lambda clss: clss.plot_name()) def get_plotable_siles(rules=False): - """ Gets the subset of siles that are plotable. + """Gets the subset of siles that are plotable. Returns --------- @@ -142,7 +161,7 @@ def get_plotable_siles(rules=False): def get_plotable_variables(variables): - """ Retrieves all plotable variables that are in the global scope. + """Retrieves all plotable variables that are in the global scope. Examples ----------- @@ -165,7 +184,6 @@ def get_plotable_variables(variables): plotables = {} for vname, obj in list(variables.items()): - if vname.startswith("_"): continue @@ -178,7 +196,7 @@ def get_plotable_variables(variables): def get_avail_presets(): - """ Gets the names of the currently available presets. + """Gets the names of the currently available presets. Returns --------- @@ -189,13 +207,14 @@ def get_avail_presets(): return list(PRESETS.keys()) -#------------------------------------- + +# ------------------------------------- # Python helpers -#------------------------------------- +# ------------------------------------- def get_nested_key(obj, nestedKey, separator="."): - """ Gets a nested key from a dictionary using a given separator. + """Gets a nested key from a dictionary using a given separator. Parameters -------- @@ -231,7 +250,7 @@ def get_nested_key(obj, nestedKey, separator="."): def modify_nested_dict(obj, nestedKey, val, separator="."): - """ Use it to modify a nested dictionary with ease. + """Use it to modify a nested dictionary with ease. It modifies the dictionary itself, does not return anything. @@ -243,7 +262,7 @@ def modify_nested_dict(obj, nestedKey, val, separator="."): The key to modify. See the separator argument for how it should look like. The function will work too if this is a simple key, without any nesting - val: + val: The new value to give to the target key. separator: str, optional (".") It defines how hierarchy is indicated in the provided key. @@ -271,7 +290,7 @@ def modify_nested_dict(obj, nestedKey, val, separator="."): def dictOfLists2listOfDicts(dictOfLists): - """ Converts a dictionary of lists to a list of dictionaries. + """Converts a dictionary of lists to a list of dictionaries. The example will make it quite clear. @@ -294,11 +313,19 @@ def dictOfLists2listOfDicts(dictOfLists): return [dict(zip(dictOfLists, t)) for t in zip(*dictOfLists.values())] -#------------------------------------- +# ------------------------------------- # Filesystem -#------------------------------------- +# ------------------------------------- + -def find_files(root_dir=Path("."), search_string = "*", depth = [0, 0], sort = True, sort_func = None, case_insensitive=False): +def find_files( + root_dir=Path("."), + search_string="*", + depth=[0, 0], + sort=True, + sort_func=None, + case_insensitive=False, +): """ Function that finds files (or directories) according to some conditions. @@ -307,12 +334,12 @@ def find_files(root_dir=Path("."), search_string = "*", depth = [0, 0], sort = T root_dir: str or Path, optional Path of the directory from which the search will start. search_string: str, optional - This is the string that will be passed to glob.glob() to find files or directories. + This is the string that will be passed to glob.glob() to find files or directories. It works mostly like bash, so you can use wildcards, for example. depth: array-like of length 2 or int, optional If it is an array: - It will specify the limits of the search. + It will specify the limits of the search. For example, depth = [1,3] will make the function search for the search_string from 1 to 3 directories deep from root_dir. (0 depth means to look for files in the root_dir) @@ -340,7 +367,12 @@ def find_files(root_dir=Path("."), search_string = "*", depth = [0, 0], sort = T root_dir = Path(root_dir) if case_insensitive: - search_string = "".join([f"[{char.upper()}{char.lower()}]" if char.isalpha() else char for char in search_string]) + search_string = "".join( + [ + f"[{char.upper()}{char.lower()}]" if char.isalpha() else char + for char in search_string + ] + ) files = [] for depth in range(depth[0], depth[1] + 1): @@ -356,7 +388,7 @@ def find_files(root_dir=Path("."), search_string = "*", depth = [0, 0], sort = T def find_plotable_siles(dir_path=None, depth=0): - """ Spans the filesystem to look for files that are registered as plotables. + """Spans the filesystem to look for files that are registered as plotables. Parameters ----------- @@ -368,7 +400,7 @@ def find_plotable_siles(dir_path=None, depth=0): If it is an array: - It will specify the limits of the search. + It will specify the limits of the search. For example, depth = [1,3] will make the function search for the searchString from 1 to 3 directories deep from root_dir. (0 depth means to look for files in the root_dir) @@ -394,12 +426,13 @@ def find_plotable_siles(dir_path=None, depth=0): return files -#------------------------------------- +# ------------------------------------- # Colors -#------------------------------------- +# ------------------------------------- + def random_color(): - """ Returns a random color in hex format + """Returns a random color in hex format Returns -------- @@ -407,11 +440,12 @@ def random_color(): the color in HEX format """ import random - return "#"+"%06x" % random.randint(0, 0xFFFFFF) + + return "#" + "%06x" % random.randint(0, 0xFFFFFF) def values_to_colors(values, scale): - """ Maps an array of numbers to colors using a colorscale. + """Maps an array of numbers to colors using a colorscale. Parameters ----------- @@ -436,11 +470,17 @@ def values_to_colors(values, scale): v_min = np.min(values) values = (values - v_min) / (np.max(values) - v_min) - scale_colors = plotly.colors.convert_colors_to_same_type(scale, colortype="tuple")[0] + scale_colors = plotly.colors.convert_colors_to_same_type(scale, colortype="tuple")[ + 0 + ] if not scale_colors and isinstance(scale, str): - scale_colors = plotly.colors.convert_colors_to_same_type(scale[0].upper() + scale[1:], colortype="tuple")[0] + scale_colors = plotly.colors.convert_colors_to_same_type( + scale[0].upper() + scale[1:], colortype="tuple" + )[0] - cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my color map", scale_colors) + cmap = matplotlib.colors.LinearSegmentedColormap.from_list( + "my color map", scale_colors + ) return plotly.colors.convert_colors_to_same_type([cmap(c) for c in values])[0] diff --git a/src/sisl/viz/processors/__init__.py b/src/sisl/viz/processors/__init__.py index 94016c4463..1e3e70de3d 100644 --- a/src/sisl/viz/processors/__init__.py +++ b/src/sisl/viz/processors/__init__.py @@ -2,4 +2,4 @@ # from .fatbands import * # from .pdos import * # from .geometry import * -# from .grid import * \ No newline at end of file +# from .grid import * diff --git a/src/sisl/viz/processors/atom.py b/src/sisl/viz/processors/atom.py index 568a8446c8..8836d69ab9 100644 --- a/src/sisl/viz/processors/atom.py +++ b/src/sisl/viz/processors/atom.py @@ -16,15 +16,23 @@ class AtomsGroup(Group, total=False): atoms: Any reduce_func: Optional[Callable] -def reduce_atom_data(atom_data: Union[DataArray, Dataset], groups: Sequence[AtomsGroup], geometry: Optional[Geometry] = None, - reduce_func: Callable = np.mean, atom_dim: str = "atom", groups_dim: str = "group", - sanitize_group: Callable = lambda x: x, group_vars: Optional[Sequence[str]] = None, - drop_empty: bool = False, fill_empty: Any = 0. + +def reduce_atom_data( + atom_data: Union[DataArray, Dataset], + groups: Sequence[AtomsGroup], + geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, + atom_dim: str = "atom", + groups_dim: str = "group", + sanitize_group: Callable = lambda x: x, + group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, + fill_empty: Any = 0.0, ) -> Union[DataArray, Dataset]: """Groups contributions of atoms into a new dimension. Given an xarray object containing atom information and the specification of groups of atoms, this function - computes the total contribution for each group of atoms. It therefore removes the atoms dimension and + computes the total contribution for each group of atoms. It therefore removes the atoms dimension and creates a new one to account for the groups. Parameters @@ -39,7 +47,7 @@ def reduce_atom_data(atom_data: Union[DataArray, Dataset], groups: Sequence[Atom If not provided, it will be searched in the ``geometry`` attribute of the ``atom_data`` object. reduce_func : Callable, optional The function that will compute the reduction along the atoms dimension once the selection is done. - This could be for example ``numpy.mean`` or ``numpy.sum``. + This could be for example ``numpy.mean`` or ``numpy.sum``. Notice that this will only be used in case the group specification doesn't specify a particular function in its "reduce_func" field, which will take preference. spin_reduce: Callable, optional @@ -67,31 +75,41 @@ def reduce_atom_data(atom_data: Union[DataArray, Dataset], groups: Sequence[Atom geometry = atom_data.attrs.get("geometry") if geometry is None: + def _sanitize_group(group): group = group.copy() group = sanitize_group(group) - atoms = group['atoms'] + atoms = group["atoms"] try: - group['atoms'] = np.array(atoms, dtype=int) + group["atoms"] = np.array(atoms, dtype=int) assert atoms.ndim == 1 except: - raise SislError("A geometry was neither provided nor found in the xarray object. Therefore we can't" - f" convert the provided atom selection ({atoms}) to an array of integers.") + raise SislError( + "A geometry was neither provided nor found in the xarray object. Therefore we can't" + f" convert the provided atom selection ({atoms}) to an array of integers." + ) - group['selector'] = group['atoms'] + group["selector"] = group["atoms"] return group + else: + def _sanitize_group(group): group = group.copy() group = sanitize_group(group) group["atoms"] = geometry._sanitize_atoms(group["atoms"]) - group['selector'] = group['atoms'] + group["selector"] = group["atoms"] return group - return group_reduce( - data=atom_data, groups=groups, reduce_dim=atom_dim, reduce_func=reduce_func, - groups_dim=groups_dim, sanitize_group=_sanitize_group, group_vars=group_vars, - drop_empty=drop_empty, fill_empty=fill_empty - ) \ No newline at end of file + data=atom_data, + groups=groups, + reduce_dim=atom_dim, + reduce_func=reduce_func, + groups_dim=groups_dim, + sanitize_group=_sanitize_group, + group_vars=group_vars, + drop_empty=drop_empty, + fill_empty=fill_empty, + ) diff --git a/src/sisl/viz/processors/axes.py b/src/sisl/viz/processors/axes.py index e37e6797d0..7b12a13e63 100644 --- a/src/sisl/viz/processors/axes.py +++ b/src/sisl/viz/processors/axes.py @@ -14,7 +14,7 @@ def sanitize_axis(ax) -> Union[str, int, np.ndarray]: ax = ax.replace("0", "a").replace("1", "b").replace("2", "c") ax = ax.lower().replace("+", "") elif isinstance(ax, int): - ax = 'abc'[ax] + ax = "abc"[ax] elif isinstance(ax, (list, tuple)): ax = np.array(ax) @@ -26,16 +26,22 @@ def sanitize_axis(ax) -> Union[str, int, np.ndarray]: invalid = ax.shape != (3,) if invalid: - raise ValueError(f"Incorrect axis passed. Axes must be one of [+-]('x', 'y', 'z', 'a', 'b', 'c', '0', '1', '2', 0, 1, 2)" + - " or a numpy array/list/tuple of shape (3, )") + raise ValueError( + f"Incorrect axis passed. Axes must be one of [+-]('x', 'y', 'z', 'a', 'b', 'c', '0', '1', '2', 0, 1, 2)" + + " or a numpy array/list/tuple of shape (3, )" + ) return ax -def sanitize_axes(val: Union[str, Sequence[Union[str, int, np.ndarray]]]) -> List[Union[str, int, np.ndarray]]: + +def sanitize_axes( + val: Union[str, Sequence[Union[str, int, np.ndarray]]] +) -> List[Union[str, int, np.ndarray]]: if isinstance(val, str): val = re.findall("[+-]?[xyzabc012]", val) return [sanitize_axis(ax) for ax in val] + def get_ax_title(ax: Union[Axis, Callable], cartesian_units: str = "Ang") -> str: """Generates the title for a given axis""" if hasattr(ax, "__name__"): @@ -45,15 +51,18 @@ def get_ax_title(ax: Union[Axis, Callable], cartesian_units: str = "Ang") -> str elif not isinstance(ax, str): title = "" elif re.match("[+-]?[xXyYzZ]", ax): - title = f'{ax.upper()} axis [{cartesian_units}]' + title = f"{ax.upper()} axis [{cartesian_units}]" elif re.match("[+-]?[aAbBcC]", ax): - title = f'{ax.upper()} lattice vector' + title = f"{ax.upper()} lattice vector" else: title = ax return title -def axis_direction(ax: Axis, cell: Optional[Union[npt.ArrayLike, Lattice]] = None) -> npt.NDArray[np.float64]: + +def axis_direction( + ax: Axis, cell: Optional[Union[npt.ArrayLike, Lattice]] = None +) -> npt.NDArray[np.float64]: """Returns the vector direction of a given axis. Parameters @@ -76,11 +85,14 @@ def axis_direction(ax: Axis, cell: Optional[Union[npt.ArrayLike, Lattice]] = Non if isinstance(ax, str) and ax[0] == "-": sign = -1 ax = ax[1] - ax = sign * direction(ax, abc=cell, xyz=np.diag([1., 1., 1.])) + ax = sign * direction(ax, abc=cell, xyz=np.diag([1.0, 1.0, 1.0])) return ax -def axes_cross_product(v1: Axis, v2: Axis, cell: Optional[Union[npt.ArrayLike, Lattice]] = None): + +def axes_cross_product( + v1: Axis, v2: Axis, cell: Optional[Union[npt.ArrayLike, Lattice]] = None +): """An enhanced version of the cross product. It is an enhanced version because both vectors accept strings that represent @@ -110,4 +122,3 @@ def axes_cross_product(v1: Axis, v2: Axis, cell: Optional[Union[npt.ArrayLike, L # If the vectors are not abc, we just need to take the cross product. return np.cross(axis_direction(v1, cell), axis_direction(v2, cell)) - diff --git a/src/sisl/viz/processors/bands.py b/src/sisl/viz/processors/bands.py index f06a3cb935..7dbea1475a 100644 --- a/src/sisl/viz/processors/bands.py +++ b/src/sisl/viz/processors/bands.py @@ -11,24 +11,29 @@ def filter_bands( - bands_data: xr.Dataset, - Erange: Optional[Tuple[float, float]] = None, - E0: float = 0, - bands_range: Optional[Tuple[int, int]] = None, - spin: Optional[int] = None + bands_data: xr.Dataset, + Erange: Optional[Tuple[float, float]] = None, + E0: float = 0, + bands_range: Optional[Tuple[int, int]] = None, + spin: Optional[int] = None, ) -> xr.Dataset: filtered_bands = bands_data.copy() # Shift the energies according to the reference energy, while keeping the # attributes (which contain the units, amongst other things) - filtered_bands['E'] = bands_data.E - E0 + filtered_bands["E"] = bands_data.E - E0 continous_bands = filtered_bands.dropna("k", how="all") # Get the bands that matter for the plot if Erange is None: if bands_range is None: # If neither E range or bands_range was provided, we will just plot the 15 bands below and above the fermi level - CB = int(continous_bands.E.where(continous_bands.E <= 0).argmax('band').max()) - bands_range = [int(max(continous_bands["band"].min(), CB - 15)), int(min(continous_bands["band"].max() + 1, CB + 16))] + CB = int( + continous_bands.E.where(continous_bands.E <= 0).argmax("band").max() + ) + bands_range = [ + int(max(continous_bands["band"].min(), CB - 15)), + int(min(continous_bands["band"].max() + 1, CB + 16)), + ] filtered_bands = filtered_bands.sel(band=slice(*bands_range)) continous_bands = filtered_bands.dropna("k", how="all") @@ -36,17 +41,19 @@ def filter_bands( # This is the new Erange # Erange = np.array([float(f'{val:.3f}') for val in [float(continous_bands.E.min() - 0.01), float(continous_bands.E.max() + 0.01)]]) else: - filtered_bands = filtered_bands.where((filtered_bands <= Erange[1]) & (filtered_bands >= Erange[0])).dropna("band", "all") + filtered_bands = filtered_bands.where( + (filtered_bands <= Erange[1]) & (filtered_bands >= Erange[0]) + ).dropna("band", "all") continous_bands = filtered_bands.dropna("k", how="all") # This is the new bands range - #bands_range = [int(continous_bands['band'].min()), int(continous_bands['band'].max())] + # bands_range = [int(continous_bands['band'].min()), int(continous_bands['band'].max())] # Give the filtered bands the same attributes as the full bands filtered_bands.attrs = bands_data.attrs filtered_bands.E.attrs = bands_data.E.attrs - filtered_bands.E.attrs['E0'] = filtered_bands.E.attrs.get('E0', 0) + E0 + filtered_bands.E.attrs["E0"] = filtered_bands.E.attrs.get("E0", 0) + E0 # Let's treat the spin if the user requested it if not isinstance(spin, (int, type(None))): @@ -62,13 +69,14 @@ def filter_bands( return filtered_bands + def style_bands( - bands_data: xr.Dataset, - bands_style: dict = {"color": "black", "width": 1}, - spindown_style: dict = {"color": "blue", "width": 1} + bands_data: xr.Dataset, + bands_style: dict = {"color": "black", "width": 1}, + spindown_style: dict = {"color": "blue", "width": 1}, ) -> xr.Dataset: """Returns the bands dataset, with the style information added to it. - + Parameters ------------ bands_data: xr.Dataset @@ -82,12 +90,12 @@ def style_bands( """ # If the user provided a styler function, apply it. if bands_style.get("styler") is not None: - if callable(bands_style['styler']): - bands_data = bands_style['styler'](data=bands_data) + if callable(bands_style["styler"]): + bands_data = bands_style["styler"](data=bands_data) - # Include default styles in bands_style, only if they are not already + # Include default styles in bands_style, only if they are not already # present in the bands dataset (e.g. because the styler included them) - default_styles = {'color': 'black', 'width': 1, 'opacity': 1} + default_styles = {"color": "black", "width": 1, "opacity": 1} for key in default_styles: if key not in bands_data.data_vars and key not in bands_style: bands_style[key] = default_styles[key] @@ -98,27 +106,32 @@ def style_bands( bands_style[key] = bands_style[key](data=bands_data) # Build the style dataarrays - if 'spin' in bands_data.dims: + if "spin" in bands_data.dims: spindown_style = {**bands_style, **spindown_style} style_arrays = {} - for key in ['color', 'width', 'opacity']: + for key in ["color", "width", "opacity"]: if isinstance(bands_style[key], xr.DataArray): if not isinstance(spindown_style[key], xr.DataArray): down_style = bands_style[key].copy(deep=True) down_style.values[:] = spindown_style[key] spindown_style[key] = down_style - - style_arrays[key] = xr.concat([bands_style[key], spindown_style[key]], dim='spin') + + style_arrays[key] = xr.concat( + [bands_style[key], spindown_style[key]], dim="spin" + ) else: - style_arrays[key] = xr.DataArray([bands_style[key], spindown_style[key]], dims=['spin']) + style_arrays[key] = xr.DataArray( + [bands_style[key], spindown_style[key]], dims=["spin"] + ) else: style_arrays = {} - for key in ['color', 'width', 'opacity']: + for key in ["color", "width", "opacity"]: style_arrays[key] = xr.DataArray(bands_style[key]) # Merge the style arrays with the bands dataset and return the styled dataset return bands_data.assign(style_arrays) + def calculate_gap(bands_data: xr.Dataset) -> dict: bands_E = bands_data.E # Calculate the band gap to store it @@ -134,13 +147,16 @@ def calculate_gap(bands_data: xr.Dataset) -> dict: gap = float(CBbot - VBtop) return { - 'gap': gap, - 'k': (VB["k"].values, CB['k'].values), - 'bands': (VB["band"].values, CB["band"].values), - 'spin': (VB["spin"].values, CB["spin"].values) if bands_data.attrs['spin'].is_polarized else (0, 0), - 'Es': (float(VBtop), float(CBbot)) + "gap": gap, + "k": (VB["k"].values, CB["k"].values), + "bands": (VB["band"].values, CB["band"].values), + "spin": (VB["spin"].values, CB["spin"].values) + if bands_data.attrs["spin"].is_polarized + else (0, 0), + "Es": (float(VBtop), float(CBbot)), } + def sanitize_k(bands_data: xr.Dataset, k: Union[float, str]) -> Optional[float]: """Returns the float value of a k point in the plot. @@ -164,9 +180,12 @@ def sanitize_k(bands_data: xr.Dataset, k: Union[float, str]) -> Optional[float]: try: san_k = float(k) except ValueError: - if 'axis' in bands_data.k.attrs and bands_data.k.attrs['axis'].get('ticktext') is not None: - ticktext = bands_data.k.attrs['axis']['ticktext'] - tickvals = bands_data.k.attrs['axis']['tickvals'] + if ( + "axis" in bands_data.k.attrs + and bands_data.k.attrs["axis"].get("ticktext") is not None + ): + ticktext = bands_data.k.attrs["axis"]["ticktext"] + tickvals = bands_data.k.attrs["axis"]["tickvals"] if k in ticktext: i_tick = ticktext.index(k) san_k = tickvals[i_tick] @@ -177,15 +196,16 @@ def sanitize_k(bands_data: xr.Dataset, k: Union[float, str]) -> Optional[float]: return san_k + def get_gap_coords( - bands_data: xr.Dataset, + bands_data: xr.Dataset, bands: Tuple[int, int], from_k: Union[float, str], to_k: Optional[Union[float, str]] = None, - spin: int = 0 + spin: int = 0, ) -> Tuple[Tuple[float, float], Tuple[float, float]]: """Calculates the coordinates of a gap given some k values. - + Parameters ----------- bands_data: xr.Dataset @@ -222,8 +242,13 @@ def get_gap_coords( ks[i] = sanitize_k(bands_data, val) VB, CB = bands - spin_bands = bands_data.E.sel(spin=spin) if "spin" in bands_data.coords else bands_data.E - Es = [spin_bands.dropna("k", "all").sel(k=k, band=band, method="nearest") for k, band in zip(ks, (VB, CB))] + spin_bands = ( + bands_data.E.sel(spin=spin) if "spin" in bands_data.coords else bands_data.E + ) + Es = [ + spin_bands.dropna("k", "all").sel(k=k, band=band, method="nearest") + for k, band in zip(ks, (VB, CB)) + ] # Get the real values of ks that have been obtained # because we might not have exactly the ks requested ks = tuple(np.ravel(E.k)[0] for E in Es) @@ -231,13 +256,17 @@ def get_gap_coords( return ks, Es + def draw_gaps( - bands_data: xr.Dataset, - gap: bool, gap_info: dict, gap_tol: float, - gap_color: Optional[str], gap_marker: Optional[dict], - direct_gaps_only: bool, - custom_gaps: Sequence[dict], - E_axis: Literal["x", "y"] + bands_data: xr.Dataset, + gap: bool, + gap_info: dict, + gap_tol: float, + gap_color: Optional[str], + gap_marker: Optional[dict], + direct_gaps_only: bool, + custom_gaps: Sequence[dict], + E_axis: Literal["x", "y"], ) -> List[dict]: """Returns the drawing actions to draw gaps. @@ -262,9 +291,9 @@ def draw_gaps( List of custom gaps to draw. Each dict can contain the keys: - "from": the k value where the gap starts. - "to": the k value where the gap ends. If not present, equal to "from". - - "spin": For which spin component do you want to draw the gap + - "spin": For which spin component do you want to draw the gap (has effect only if spin is polarized). Optional. If None and the bands - are polarized, the gap will be drawn for both spin components. + are polarized, the gap will be drawn for both spin components. - "color": Color of the line that draws the gap. Optional. - "marker": Marker specification for the limits of the gap. Optional. E_axis: Literal["x", "y"] @@ -274,8 +303,7 @@ def draw_gaps( # Draw gaps if gap: - - gapKs = [np.atleast_1d(k) for k in gap_info['k']] + gapKs = [np.atleast_1d(k) for k in gap_info["k"]] # Remove "equivalent" gaps def clear_equivalent(ks): @@ -291,21 +319,26 @@ def clear_equivalent(ks): all_gapKs = itertools.product(*[clear_equivalent(ks) for ks in gapKs]) for gap_ks in all_gapKs: - if direct_gaps_only and abs(gap_ks[1] - gap_ks[0]) > gap_tol: continue - ks, Es = get_gap_coords(bands_data, gap_info['bands'], *gap_ks, spin=gap_info.get('spin', [0])[0]) + ks, Es = get_gap_coords( + bands_data, + gap_info["bands"], + *gap_ks, + spin=gap_info.get("spin", [0])[0], + ) name = "Gap" draw_actions.append( - draw_gap(ks, Es, color=gap_color, name=name, marker=gap_marker, E_axis=E_axis) + draw_gap( + ks, Es, color=gap_color, name=name, marker=gap_marker, E_axis=E_axis + ) ) # Draw the custom gaps. These are gaps that do not necessarily represent # the maximum and the minimum of the VB and CB. for custom_gap in custom_gaps: - requested_spin = custom_gap.get("spin", None) if requested_spin is None: requested_spin = [0, 1] @@ -318,23 +351,34 @@ def clear_equivalent(ks): to_k = custom_gap.get("to", from_k) color = custom_gap.get("color", None) name = f"Gap ({from_k}-{to_k})" - ks, Es = get_gap_coords(bands_data, gap_info['bands'], from_k, to_k, spin=spin) + ks, Es = get_gap_coords( + bands_data, gap_info["bands"], from_k, to_k, spin=spin + ) draw_actions.append( - draw_gap(ks, Es, color=color, name=name, marker=custom_gap.get("marker", {}), E_axis=E_axis) + draw_gap( + ks, + Es, + color=color, + name=name, + marker=custom_gap.get("marker", {}), + E_axis=E_axis, + ) ) return draw_actions + def draw_gap( - ks: Tuple[float, float], - Es: Tuple[float, float], - color: Optional[str] = None, marker: dict = {}, - name: str = "Gap", - E_axis: Literal["x", "y"] = "y" + ks: Tuple[float, float], + Es: Tuple[float, float], + color: Optional[str] = None, + marker: dict = {}, + name: str = "Gap", + E_axis: Literal["x", "y"] = "y", ) -> dict: """Returns the drawing action to draw a gap. - + Parameters ------------ ks: tuple of float @@ -357,11 +401,13 @@ def draw_gap( else: raise ValueError(f"E_axis must be either 'x' or 'y', but was {E_axis}") - return plot_actions.draw_line(**{ - **coords, - 'text': [f'Gap: {Es[1] - Es[0]:.3f} eV', ''], - 'name': name, - 'textposition': 'top right', - 'marker': {"size": 7, 'color': color, **marker}, - 'line': {'color': color}, - }) \ No newline at end of file + return plot_actions.draw_line( + **{ + **coords, + "text": [f"Gap: {Es[1] - Es[0]:.3f} eV", ""], + "name": name, + "textposition": "top right", + "marker": {"size": 7, "color": color, **marker}, + "line": {"color": color}, + } + ) diff --git a/src/sisl/viz/processors/cell.py b/src/sisl/viz/processors/cell.py index ba17decd57..b728357e9e 100644 --- a/src/sisl/viz/processors/cell.py +++ b/src/sisl/viz/processors/cell.py @@ -9,10 +9,11 @@ from sisl.lattice import Lattice, LatticeChild -#from ...types import CellLike -#from .coords import project_to_axes, CoordsDataset +# from ...types import CellLike +# from .coords import project_to_axes, CoordsDataset + +# CellDataset = CoordsDataset -#CellDataset = CoordsDataset def is_cartesian_unordered(cell: CellLike, tol: float = 1e-3) -> bool: """Whether a cell has cartesian axes as lattice vectors, regardless of their order. @@ -28,9 +29,16 @@ def is_cartesian_unordered(cell: CellLike, tol: float = 1e-3) -> bool: cell = cell.cell bigger_than_tol = abs(cell) > tol - return bigger_than_tol.sum() == 3 and bigger_than_tol.any(axis=0).all() and bigger_than_tol.any(axis=1).all() + return ( + bigger_than_tol.sum() == 3 + and bigger_than_tol.any(axis=0).all() + and bigger_than_tol.any(axis=1).all() + ) -def is_1D_cartesian(cell: CellLike, coord_ax: Literal["x", "y", "z"], tol: float = 1e-3) -> bool: + +def is_1D_cartesian( + cell: CellLike, coord_ax: Literal["x", "y", "z"], tol: float = 1e-3 +) -> bool: """Whether a cell contains only one vector that contributes only to a given coordinate. That is, one vector follows the direction of the cartesian axis and the other vectors don't @@ -54,24 +62,28 @@ def is_1D_cartesian(cell: CellLike, coord_ax: Literal["x", "y", "z"], tol: float is_1D_cartesian = lattice_vecs.shape[0] == 1 return is_1D_cartesian and (cell[lattice_vecs[0]] > tol).sum() == 1 + def infer_cell_axes(cell: CellLike, axes: List[str], tol: float = 1e-3) -> List[int]: """Returns the indices of the lattice vectors that correspond to the given axes.""" if isinstance(cell, (Lattice, LatticeChild)): cell = cell.cell - + grid_axes = [] for ax in axes: if ax in ("x", "y", "z"): coord_index = "xyz".index(ax) lattice_vecs = np.where(cell[:, coord_index] > tol)[0] if lattice_vecs.shape[0] != 1: - raise ValueError(f"There are {lattice_vecs.shape[0]} lattice vectors that contribute to the {'xyz'[coord_index]} coordinate.") + raise ValueError( + f"There are {lattice_vecs.shape[0]} lattice vectors that contribute to the {'xyz'[coord_index]} coordinate." + ) grid_axes.append(lattice_vecs[0]) else: grid_axes.append("abc".index(ax)) return grid_axes + def gen_cell_dataset(lattice: Union[Lattice, LatticeChild]) -> CellDataset: """Generates a dataset with the vertices of the cell.""" if isinstance(lattice, LatticeChild): @@ -79,23 +91,28 @@ def gen_cell_dataset(lattice: Union[Lattice, LatticeChild]) -> CellDataset: return Dataset( {"xyz": (("a", "b", "c", "axis"), lattice.vertices())}, - coords={"a": [0,1], "b": [0, 1], "c": [0, 1], "axis": [0,1,2]}, - attrs={ - "lattice": lattice - } + coords={"a": [0, 1], "b": [0, 1], "c": [0, 1], "axis": [0, 1, 2]}, + attrs={"lattice": lattice}, ) + class CellStyleSpec(TypedDict): color: Any width: Any opacity: Any + class PartialCellStyleSpec(TypedDict, total=False): color: Any width: Any opacity: Any -def cell_to_lines(cell_data: CellDataset, how: Literal["box", "axes"], cell_style: PartialCellStyleSpec = {}) -> CellDataset: + +def cell_to_lines( + cell_data: CellDataset, + how: Literal["box", "axes"], + cell_style: PartialCellStyleSpec = {}, +) -> CellDataset: """Converts a cell dataset to lines that should be plotted. Parameters @@ -110,35 +127,61 @@ def cell_to_lines(cell_data: CellDataset, how: Literal["box", "axes"], cell_styl Style of the cell lines. A dictionary optionally containing the keys "color", "width" and "opacity". """ - cell_data = cell_data.reindex(a=[0,1,2], b=[0,1,2], c=[0,1,2]) - + cell_data = cell_data.reindex(a=[0, 1, 2], b=[0, 1, 2], c=[0, 1, 2]) + if how == "box": - verts = np.array([ - (0, 0, 0), (0, 1, 0), (1, 1, 0), (1, 1, 1), (0, 1, 1), (0, 1, 0), - (2, 2, 2), - (0, 1, 1), (0, 0, 1), (0, 0, 0), (1, 0, 0), (1, 0, 1), (0, 0, 1), - (2, 2, 2), - (1, 1, 0), (1, 0, 0), - (2, 2, 2), - (1, 1, 1), (1, 0, 1) - ]) + verts = np.array( + [ + (0, 0, 0), + (0, 1, 0), + (1, 1, 0), + (1, 1, 1), + (0, 1, 1), + (0, 1, 0), + (2, 2, 2), + (0, 1, 1), + (0, 0, 1), + (0, 0, 0), + (1, 0, 0), + (1, 0, 1), + (0, 0, 1), + (2, 2, 2), + (1, 1, 0), + (1, 0, 0), + (2, 2, 2), + (1, 1, 1), + (1, 0, 1), + ] + ) elif how == "axes": - verts = np.array([ - (0, 0, 0), (1, 0, 0), (2, 2, 2), - (0, 0, 0), (0, 1, 0), (2, 2, 2), - (0, 0, 0), (0, 0, 1), (2, 2, 2), - ]) + verts = np.array( + [ + (0, 0, 0), + (1, 0, 0), + (2, 2, 2), + (0, 0, 0), + (0, 1, 0), + (2, 2, 2), + (0, 0, 0), + (0, 0, 1), + (2, 2, 2), + ] + ) else: - raise ValueError(f"'how' argument must be either 'box' or 'axes', but got {how}") - - xyz = cell_data.xyz.values[verts[:,0],verts[:,1],verts[:,2]] - - cell_data = cell_data.assign({ - "xyz": (("point_index", "axis"), xyz), - "color": cell_style.get("color"), - "width": cell_style.get("width"), - "opacity": cell_style.get("opacity"), - }) + raise ValueError( + f"'how' argument must be either 'box' or 'axes', but got {how}" + ) + + xyz = cell_data.xyz.values[verts[:, 0], verts[:, 1], verts[:, 2]] + + cell_data = cell_data.assign( + { + "xyz": (("point_index", "axis"), xyz), + "color": cell_style.get("color"), + "width": cell_style.get("width"), + "opacity": cell_style.get("opacity"), + } + ) return cell_data diff --git a/src/sisl/viz/processors/coords.py b/src/sisl/viz/processors/coords.py index 3c77e1f206..7322dbe8e1 100644 --- a/src/sisl/viz/processors/coords.py +++ b/src/sisl/viz/processors/coords.py @@ -13,11 +13,14 @@ from .axes import axes_cross_product, axis_direction, get_ax_title -#from ...types import Axes, CellLike, Axis +# from ...types import Axes, CellLike, Axis CoordsDataset = Dataset -def projected_2Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], xaxis: Axis = "x", yaxis: Axis = "y") -> npt.NDArray[np.float64]: + +def projected_2Dcoords( + cell: CellLike, xyz: npt.NDArray[np.float64], xaxis: Axis = "x", yaxis: Axis = "y" +) -> npt.NDArray[np.float64]: """Moves the 3D positions of the atoms to a 2D supspace. In this way, we can plot the structure from the "point of view" that we want. @@ -33,9 +36,9 @@ def projected_2Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], xaxis: Axis the geometry for which you want the projected coords xyz: array-like of shape (natoms, 3), optional the 3D coordinates that we want to project. - otherwise they are taken from the geometry. + otherwise they are taken from the geometry. xaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional - the direction to be displayed along the X axis. + the direction to be displayed along the X axis. yaxis: {"x", "y", "z", "a", "b", "c"} or array-like of shape 3, optional the direction to be displayed along the X axis. @@ -70,6 +73,7 @@ def projected_2Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], xaxis: Axis return np.dot(xyz, icell.T)[..., coord_indices] + def projected_1Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], axis: Axis = "x"): """ Moves the 3D positions of the atoms to a 2D supspace. @@ -87,7 +91,7 @@ def projected_1Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], axis: Axis the geometry for which you want the projected coords xyz: array-like of shape (natoms, 3), optional the 3D coordinates that we want to project. - otherwise they are taken from the geometry. + otherwise they are taken from the geometry. axis: {"x", "y", "z", "a", "b", "c", "1", "2", "3"} or array-like of shape 3, optional the direction to be displayed along the X axis. nsc: array-like of shape (3, ), optional @@ -104,19 +108,19 @@ def projected_1Dcoords(cell: CellLike, xyz: npt.NDArray[np.float64], axis: Axis cell = cell.cell if isinstance(axis, str) and axis in ("a", "b", "c", "0", "1", "2"): - return projected_2Dcoords(cell, xyz, xaxis=axis, yaxis="a" if axis == "c" else "c")[..., 0] + return projected_2Dcoords( + cell, xyz, xaxis=axis, yaxis="a" if axis == "c" else "c" + )[..., 0] # Get the direction that the axis represents axis = axis_direction(axis, cell) - return xyz.dot(axis/fnorm(axis)) / fnorm(axis) + return xyz.dot(axis / fnorm(axis)) / fnorm(axis) -def coords_depth( - coords_data: CoordsDataset, - axes: Axes -) -> npt.NDArray[np.float64]: + +def coords_depth(coords_data: CoordsDataset, axes: Axes) -> npt.NDArray[np.float64]: """Computes the depth of 3D points as projected in a 2D plane - + Parameters ---------- coords_data: CoordsDataset @@ -125,29 +129,31 @@ def coords_depth( The axes that define the plane where the coordinates are projected. """ cell = _get_cell_from_dataset(coords_data=coords_data) - + depth_vector = axes_cross_product(axes[0], axes[1], cell) depth = project_to_axes(coords_data, axes=[depth_vector]).x.values - + return depth + def sphere( - center: npt.ArrayLike = [0, 0, 0], - r: float = 1, - vertices: int = 10 + center: npt.ArrayLike = [0, 0, 0], r: float = 1, vertices: int = 10 ) -> Dict[str, np.ndarray]: """Computes a mesh defining a sphere.""" - phi, theta = np.mgrid[0.0:np.pi: 1j*vertices, 0.0:2.0*np.pi: 1j*vertices] + phi, theta = np.mgrid[ + 0.0 : np.pi : 1j * vertices, 0.0 : 2.0 * np.pi : 1j * vertices + ] center = np.array(center) phi = np.ravel(phi) theta = np.ravel(theta) - x = center[0] + r*np.sin(phi)*np.cos(theta) - y = center[1] + r*np.sin(phi)*np.sin(theta) - z = center[2] + r*np.cos(phi) + x = center[0] + r * np.sin(phi) * np.cos(theta) + y = center[1] + r * np.sin(phi) * np.sin(theta) + z = center[2] + r * np.cos(phi) + + return {"x": x, "y": y, "z": z} - return {'x': x, 'y': y, 'z': z} def _get_cell_from_dataset(coords_data: CoordsDataset) -> npt.NDArray[np.float64]: cell = coords_data.attrs.get("cell") @@ -156,10 +162,15 @@ def _get_cell_from_dataset(coords_data: CoordsDataset) -> npt.NDArray[np.float64 cell = coords_data.lattice.cell else: cell = coords_data.geometry.cell - + return cell -def projected_1D_data(coords_data: CoordsDataset, axis: Axis = "x", dataaxis_1d: Union[Callable, npt.NDArray, None] = None) -> CoordsDataset: + +def projected_1D_data( + coords_data: CoordsDataset, + axis: Axis = "x", + dataaxis_1d: Union[Callable, npt.NDArray, None] = None, +) -> CoordsDataset: cell = _get_cell_from_dataset(coords_data=coords_data) xyz = coords_data.xyz.values @@ -182,7 +193,13 @@ def projected_1D_data(coords_data: CoordsDataset, axis: Axis = "x", dataaxis_1d: return coords_data -def projected_2D_data(coords_data: CoordsDataset, xaxis: Axis = "x", yaxis: Axis = "y", sort_by_depth: bool = False) -> CoordsDataset: + +def projected_2D_data( + coords_data: CoordsDataset, + xaxis: Axis = "x", + yaxis: Axis = "y", + sort_by_depth: bool = False, +) -> CoordsDataset: cell = _get_cell_from_dataset(coords_data=coords_data) xyz = coords_data.xyz.values @@ -202,6 +219,7 @@ def projected_2D_data(coords_data: CoordsDataset, xaxis: Axis = "x", yaxis: Axis return coords_data + def projected_3D_data(coords_data: CoordsDataset) -> CoordsDataset: x, y, z = np.moveaxis(coords_data.xyz.values, -1, 0) dims = coords_data.xyz.dims[:-1] @@ -210,23 +228,29 @@ def projected_3D_data(coords_data: CoordsDataset) -> CoordsDataset: return coords_data + def project_to_axes( - coords_data: CoordsDataset, axes: Axes, - dataaxis_1d: Optional[Union[npt.ArrayLike, Callable]] = None, + coords_data: CoordsDataset, + axes: Axes, + dataaxis_1d: Optional[Union[npt.ArrayLike, Callable]] = None, sort_by_depth: bool = False, - cartesian_units: str = "Ang" + cartesian_units: str = "Ang", ) -> CoordsDataset: - ndim = len(axes) + ndim = len(axes) if ndim == 3: xaxis, yaxis, zaxis = axes coords_data = projected_3D_data(coords_data) elif ndim == 2: xaxis, yaxis = axes - coords_data = projected_2D_data(coords_data, xaxis=xaxis, yaxis=yaxis, sort_by_depth=sort_by_depth) + coords_data = projected_2D_data( + coords_data, xaxis=xaxis, yaxis=yaxis, sort_by_depth=sort_by_depth + ) elif ndim == 1: xaxis = axes[0] yaxis = dataaxis_1d - coords_data = projected_1D_data(coords_data, axis=xaxis, dataaxis_1d=dataaxis_1d) + coords_data = projected_1D_data( + coords_data, axis=xaxis, dataaxis_1d=dataaxis_1d + ) plot_axes = ["x", "y", "z"][:ndim] @@ -234,7 +258,7 @@ def project_to_axes( coords_data[plot_ax].attrs["axis"] = { "title": get_ax_title(ax, cartesian_units=cartesian_units), } - - coords_data.attrs['ndim'] = ndim + + coords_data.attrs["ndim"] = ndim return coords_data diff --git a/src/sisl/viz/processors/data.py b/src/sisl/viz/processors/data.py index ef33150285..44ff9a9238 100644 --- a/src/sisl/viz/processors/data.py +++ b/src/sisl/viz/processors/data.py @@ -4,22 +4,28 @@ DataInstance = TypeVar("DataInstance", bound=Data) -def accept_data(data: DataInstance, cls: Type[Data], check: bool = True) -> DataInstance: +def accept_data( + data: DataInstance, cls: Type[Data], check: bool = True +) -> DataInstance: if not isinstance(data, cls): - raise TypeError(f"Data must be of type {cls.__name__} and was {type(data).__name__}") - + raise TypeError( + f"Data must be of type {cls.__name__} and was {type(data).__name__}" + ) + if check: data.sanity_check() return data -def extract_data(data: Data, cls: Type[Data], check: bool = True): +def extract_data(data: Data, cls: Type[Data], check: bool = True): if not isinstance(data, cls): - raise TypeError(f"Data must be of type {cls.__name__} and was {type(data).__name__}") - + raise TypeError( + f"Data must be of type {cls.__name__} and was {type(data).__name__}" + ) + if check: data.sanity_check() - return data._data \ No newline at end of file + return data._data diff --git a/src/sisl/viz/processors/eigenstate.py b/src/sisl/viz/processors/eigenstate.py index 942d4b16da..3d5f19a358 100644 --- a/src/sisl/viz/processors/eigenstate.py +++ b/src/sisl/viz/processors/eigenstate.py @@ -5,14 +5,16 @@ import sisl -def get_eigenstate(eigenstate: sisl.EigenstateElectron, i: int) -> sisl.EigenstateElectron: +def get_eigenstate( + eigenstate: sisl.EigenstateElectron, i: int +) -> sisl.EigenstateElectron: """Gets the i-th wavefunction from the eigenstate. It takes into account if the info dictionary has an "index" key, which might be present for example if the eigenstate object does not contain the full set of wavefunctions, to indicate which wavefunctions are present. - + Parameters ---------- eigenstate : sisl.EigenstateElectron @@ -24,19 +26,26 @@ def get_eigenstate(eigenstate: sisl.EigenstateElectron, i: int) -> sisl.Eigensta if "index" in eigenstate.info: wf_i = np.nonzero(eigenstate.info["index"] == i)[0] if len(wf_i) == 0: - raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {eigenstate.info['index']}.") + raise ValueError( + f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {eigenstate.info['index']}." + ) wf_i = wf_i[0] else: max_index = eigenstate.shape[0] if i > max_index: - raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}].") + raise ValueError( + f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}]." + ) wf_i = i return eigenstate[wf_i] -def eigenstate_geometry(eigenstate: sisl.EigenstateElectron, geometry: Optional[sisl.Geometry] = None) -> Union[sisl.Geometry, None]: + +def eigenstate_geometry( + eigenstate: sisl.EigenstateElectron, geometry: Optional[sisl.Geometry] = None +) -> Union[sisl.Geometry, None]: """Returns the geometry associated with the eigenstate. - + Parameters ---------- eigenstate : sisl.EigenstateElectron @@ -49,10 +58,15 @@ def eigenstate_geometry(eigenstate: sisl.EigenstateElectron, geometry: Optional[ geometry = getattr(eigenstate, "parent", None) if geometry is not None and not isinstance(geometry, sisl.Geometry): geometry = getattr(geometry, "geometry", None) - + return geometry -def tile_if_k(geometry: sisl.Geometry, nsc: Tuple[int, int, int], eigenstate: sisl.EigenstateElectron) -> sisl.Geometry: + +def tile_if_k( + geometry: sisl.Geometry, + nsc: Tuple[int, int, int], + eigenstate: sisl.EigenstateElectron, +) -> sisl.Geometry: """Tiles the geometry if the eigenstate does not correspond to gamma. If we are calculating the wavefunction for any point other than gamma, @@ -69,10 +83,12 @@ def tile_if_k(geometry: sisl.Geometry, nsc: Tuple[int, int, int], eigenstate: si eigenstate : sisl.EigenstateElectron The eigenstate for which the wavefunction was calculated. """ - + tiled_geometry = geometry - k = eigenstate.info.get("k", (1, 1, 1) if np.iscomplexobj(eigenstate.state) else (0, 0, 0)) + k = eigenstate.info.get( + "k", (1, 1, 1) if np.iscomplexobj(eigenstate.state) else (0, 0, 0) + ) for ax, sc_i in enumerate(nsc): if k[ax] != 0: @@ -80,7 +96,10 @@ def tile_if_k(geometry: sisl.Geometry, nsc: Tuple[int, int, int], eigenstate: si return tiled_geometry -def get_grid_nsc(nsc: Tuple[int, int, int], eigenstate: sisl.EigenstateElectron) -> Tuple[int, int, int]: + +def get_grid_nsc( + nsc: Tuple[int, int, int], eigenstate: sisl.EigenstateElectron +) -> Tuple[int, int, int]: """Returns the supercell to display once the geometry is tiled. The geometry must be tiled if the eigenstate is not calculated at gamma, @@ -94,13 +113,21 @@ def get_grid_nsc(nsc: Tuple[int, int, int], eigenstate: sisl.EigenstateElectron) eigenstate : sisl.EigenstateElectron The eigenstate for which the wavefunction was calculated. """ - k = eigenstate.info.get("k", (1, 1, 1) if np.iscomplexobj(eigenstate.state) else (0, 0, 0)) + k = eigenstate.info.get( + "k", (1, 1, 1) if np.iscomplexobj(eigenstate.state) else (0, 0, 0) + ) return tuple(nx if kx == 0 else 1 for nx, kx in zip(nsc, k)) -def create_wf_grid(eigenstate: sisl.EigenstateElectron, grid_prec: float = 0.2, grid: Optional[sisl.Grid] = None, geometry: Optional[sisl.Geometry] = None) -> sisl.Grid: + +def create_wf_grid( + eigenstate: sisl.EigenstateElectron, + grid_prec: float = 0.2, + grid: Optional[sisl.Grid] = None, + geometry: Optional[sisl.Geometry] = None, +) -> sisl.Grid: """Creates a grid to display the wavefunction. - + Parameters ---------- eigenstate : sisl.EigenstateElectron @@ -121,7 +148,13 @@ def create_wf_grid(eigenstate: sisl.EigenstateElectron, grid_prec: float = 0.2, return grid -def project_wavefunction(eigenstate: sisl.EigenstateElectron, grid_prec: float = 0.2, grid: Optional[sisl.Grid] = None, geometry: Optional[sisl.Geometry] = None) -> sisl.Grid: + +def project_wavefunction( + eigenstate: sisl.EigenstateElectron, + grid_prec: float = 0.2, + grid: Optional[sisl.Grid] = None, + geometry: Optional[sisl.Geometry] = None, +) -> sisl.Grid: """Projects the wavefunction from an eigenstate into a grid. Parameters @@ -141,11 +174,14 @@ def project_wavefunction(eigenstate: sisl.EigenstateElectron, grid_prec: float = grid = create_wf_grid(eigenstate, grid_prec=grid_prec, grid=grid, geometry=geometry) # Ensure we are dealing with the R gauge - eigenstate.change_gauge('R') + eigenstate.change_gauge("R") # Finally, insert the wavefunction values into the grid. sisl.physics.electron.wavefunction( - eigenstate.state, grid, geometry=geometry, spinor=0, + eigenstate.state, + grid, + geometry=geometry, + spinor=0, ) - return grid \ No newline at end of file + return grid diff --git a/src/sisl/viz/processors/geometry.py b/src/sisl/viz/processors/geometry.py index 89431c01ac..7286e4bb79 100644 --- a/src/sisl/viz/processors/geometry.py +++ b/src/sisl/viz/processors/geometry.py @@ -17,7 +17,7 @@ from ..data_sources.atom_data import AtomDefaultColors, AtomIsGhost, AtomPeriodicTable from .coords import CoordsDataset, projected_1Dcoords, projected_2Dcoords -#from ...types import AtomsArgument, GeometryLike, PathLike +# from ...types import AtomsArgument, GeometryLike, PathLike GeometryDataset = CoordsDataset AtomsDataset = GeometryDataset @@ -34,9 +34,10 @@ # def geometry_from_obj(obj: GeometryLike) -> Geometry: # return Geometry.new(obj) + def tile_geometry(geometry: Geometry, nsc: Tuple[int, int, int]) -> Geometry: """Tiles a geometry along the three lattice vectors. - + Parameters ----------- geometry: sisl.Geometry @@ -51,6 +52,7 @@ def tile_geometry(geometry: Geometry, nsc: Tuple[int, int, int]) -> Geometry: return tiled_geometry + def find_all_bonds(geometry: Geometry, tol: float = 0.2) -> BondsDataset: """ Finds all bonds present in a geometry. @@ -76,22 +78,26 @@ def find_all_bonds(geometry: Geometry, tol: float = 0.2) -> BondsDataset: neighs: npt.NDArray[np.int32] = geometry.close(at, R=[0.1, 3])[-1] for neigh in neighs[neighs > at]: - summed_radius = pt.radius([abs(geometry.atoms[at].Z), abs(geometry.atoms[neigh % geometry.na].Z)]).sum() - bond_thresh = (1+tol) * summed_radius - if bond_thresh > fnorm(geometry[neigh] - geometry[at]): + summed_radius = pt.radius( + [abs(geometry.atoms[at].Z), abs(geometry.atoms[neigh % geometry.na].Z)] + ).sum() + bond_thresh = (1 + tol) * summed_radius + if bond_thresh > fnorm(geometry[neigh] - geometry[at]): bonds.append([at, neigh]) if len(bonds) == 0: bonds = np.empty((0, 2), dtype=np.int64) - return Dataset({ - "bonds": (("bond_index", "bond_atom"), np.array(bonds, dtype=np.int64)) - }, - coords={"bond_index": np.arange(len(bonds)), "bond_atom": [0, 1]}, - attrs={"geometry": geometry} + return Dataset( + {"bonds": (("bond_index", "bond_atom"), np.array(bonds, dtype=np.int64))}, + coords={"bond_index": np.arange(len(bonds)), "bond_atom": [0, 1]}, + attrs={"geometry": geometry}, ) -def get_atoms_bonds(bonds: npt.NDArray[np.int32], atoms: npt.ArrayLike, ret_mask: bool = False) -> npt.NDArray[Union[np.float64, np.bool8]]: + +def get_atoms_bonds( + bonds: npt.NDArray[np.int32], atoms: npt.ArrayLike, ret_mask: bool = False +) -> npt.NDArray[Union[np.float64, np.bool8]]: """Gets the bonds where the given atoms are involved. Parameters @@ -105,12 +111,15 @@ def get_atoms_bonds(bonds: npt.NDArray[np.int32], atoms: npt.ArrayLike, ret_mask mask = np.isin(bonds, atoms).any(axis=-1) if ret_mask: return mask - + return bonds[mask] -def sanitize_atoms(geometry: Geometry, atoms: AtomsArgument = None) -> npt.NDArray[np.int32]: + +def sanitize_atoms( + geometry: Geometry, atoms: AtomsArgument = None +) -> npt.NDArray[np.int32]: """Sanitizes the atoms argument to a np.ndarray of shape (natoms,). - + This is the same as `geometry._sanitize_atoms` but ensuring that the result is a numpy array of 1 dimension. @@ -124,7 +133,10 @@ def sanitize_atoms(geometry: Geometry, atoms: AtomsArgument = None) -> npt.NDArr atoms = geometry._sanitize_atoms(atoms) return np.atleast_1d(atoms) -def tile_data_sc(geometry_data: GeometryDataset, nsc: Tuple[int, int, int] = (1, 1, 1)) -> GeometryDataset: + +def tile_data_sc( + geometry_data: GeometryDataset, nsc: Tuple[int, int, int] = (1, 1, 1) +) -> GeometryDataset: """Tiles coordinates from unit cell to a supercell. Parameters @@ -140,42 +152,50 @@ def tile_data_sc(geometry_data: GeometryDataset, nsc: Tuple[int, int, int] = (1, xyz_shape = geometry_data.xyz.shape # Create a fake geometry - fake_geom = Geometry(xyz=geometry_data.xyz.values.reshape(-1, 3), - lattice=geometry_data.geometry.lattice.copy(), - atoms=1 + fake_geom = Geometry( + xyz=geometry_data.xyz.values.reshape(-1, 3), + lattice=geometry_data.geometry.lattice.copy(), + atoms=1, ) sc_offs = np.array(list(itertools.product(*[range(n) for n in nsc]))) - - sc_xyz = np.array([ - fake_geom.axyz(isc=sc_off) for sc_off in sc_offs - ]).reshape((total_sc, *xyz_shape)) - + + sc_xyz = np.array([fake_geom.axyz(isc=sc_off) for sc_off in sc_offs]).reshape( + (total_sc, *xyz_shape) + ) + # Build the new dataset sc_atoms = geometry_data.assign({"xyz": (("isc", *geometry_data.xyz.dims), sc_xyz)}) sc_atoms = sc_atoms.assign_coords(isc=range(total_sc)) - + return sc_atoms -def stack_sc_data(geometry_data: GeometryDataset, newname: str, dims: Sequence[str]) -> GeometryDataset: + +def stack_sc_data( + geometry_data: GeometryDataset, newname: str, dims: Sequence[str] +) -> GeometryDataset: """Stacks the supercell coordinate with others. - + Parameters ----------- geometry_data: GeometryDataset the dataset for which we want to stack the supercell coordinates. newname: str """ - + return geometry_data.stack(**{newname: ["isc", *dims]}).transpose(newname, ...) + class AtomsStyleSpec(TypedDict): color: Any size: Any opacity: Any vertices: Any -def parse_atoms_style(geometry: Geometry, atoms_style: Sequence[AtomsStyleSpec], scale: float = 1.) -> AtomsDataset: + +def parse_atoms_style( + geometry: Geometry, atoms_style: Sequence[AtomsStyleSpec], scale: float = 1.0 +) -> AtomsDataset: """Parses atom style specifications to a dataset of styles. Parameters @@ -195,10 +215,10 @@ def parse_atoms_style(geometry: Geometry, atoms_style: Sequence[AtomsStyleSpec], { "color": AtomDefaultColors(), "size": AtomPeriodicTable(what="radius"), - "opacity": AtomIsGhost(fill_true=0.4, fill_false=1.), + "opacity": AtomIsGhost(fill_true=0.4, fill_false=1.0), "vertices": 15, }, - *atoms_style + *atoms_style, ] def _tile_if_needed(atoms, spec): @@ -215,9 +235,9 @@ def _tile_if_needed(atoms, spec): # Initialize the styles. parsed_atoms_style = { - "color": np.empty((geometry.na, ), dtype=object), - "size": np.empty((geometry.na, ), dtype=float), - "vertices": np.empty((geometry.na, ), dtype=int), + "color": np.empty((geometry.na,), dtype=object), + "size": np.empty((geometry.na,), dtype=float), + "vertices": np.empty((geometry.na,), dtype=int), "opacity": np.empty((geometry.na), dtype=float), } @@ -235,34 +255,41 @@ def _tile_if_needed(atoms, spec): parsed_atoms_style[key][atoms] = _tile_if_needed(atoms, style) # Apply the scale - parsed_atoms_style['size'] = parsed_atoms_style['size'] * scale + parsed_atoms_style["size"] = parsed_atoms_style["size"] * scale # Convert colors to numbers if possible try: - parsed_atoms_style['color'] = parsed_atoms_style['color'].astype(float) + parsed_atoms_style["color"] = parsed_atoms_style["color"].astype(float) except: pass # Add coordinates to the values according to their unique dimensionality. data_vars = {} for k, value in parsed_atoms_style.items(): - if (k != "color" or value.dtype not in (float, int)): + if k != "color" or value.dtype not in (float, int): unique = np.unique(value) if len(unique) == 1: data_vars[k] = unique[0] continue data_vars[k] = ("atom", value) - + return Dataset( data_vars, coords={"atom": range(geometry.na)}, attrs={"geometry": geometry}, ) -def sanitize_arrows(geometry: Geometry, arrows: Sequence[AtomArrowSpec], atoms: AtomsArgument, ndim: int, axes: Sequence[str]) -> List[dict]: + +def sanitize_arrows( + geometry: Geometry, + arrows: Sequence[AtomArrowSpec], + atoms: AtomsArgument, + ndim: int, + axes: Sequence[str], +) -> List[dict]: """Sanitizes a list of arrow specifications. - - Each arrow specification in the output has the atoms sanitized and + + Each arrow specification in the output has the atoms sanitized and the data with the shape (natoms, ndim). Parameters @@ -284,17 +311,21 @@ def sanitize_arrows(geometry: Geometry, arrows: Sequence[AtomArrowSpec], atoms: def _sanitize_spec(arrow_spec): arrow_spec = AtomArrowSpec(**arrow_spec) arrow_spec = asdict(arrow_spec) - - arrow_spec["atoms"] = np.atleast_1d(geometry._sanitize_atoms(arrow_spec["atoms"])) + + arrow_spec["atoms"] = np.atleast_1d( + geometry._sanitize_atoms(arrow_spec["atoms"]) + ) arrow_atoms = arrow_spec["atoms"] not_displayed = set(arrow_atoms) - set(atoms) if not_displayed: - warn(f"Arrow data for atoms {not_displayed} will not be displayed because these atoms are not displayed.") + warn( + f"Arrow data for atoms {not_displayed} will not be displayed because these atoms are not displayed." + ) if set(atoms) == set(atoms) - set(arrow_atoms): # Then it makes no sense to store arrows, as nothing will be drawn return None - + arrow_data = np.full((geometry.na, ndim), np.nan, dtype=np.float64) provided_data = np.array(arrow_spec["data"]) @@ -303,12 +334,14 @@ def _sanitize_spec(arrow_spec): provided_data = projected_1Dcoords(geometry, provided_data, axis=axes[0]) provided_data = np.expand_dims(provided_data, axis=-1) elif ndim == 2: - provided_data = projected_2Dcoords(geometry, provided_data, xaxis=axes[0], yaxis=axes[1]) + provided_data = projected_2Dcoords( + geometry, provided_data, xaxis=axes[0], yaxis=axes[1] + ) arrow_data[arrow_atoms] = provided_data arrow_spec["data"] = arrow_data[atoms] - #arrow_spec["data"] = self._tile_atomic_data(arrow_spec["data"]) + # arrow_spec["data"] = self._tile_atomic_data(arrow_spec["data"]) return arrow_spec @@ -322,28 +355,37 @@ def _sanitize_spec(arrow_spec): return [arrow_spec for arrow_spec in san_arrows if arrow_spec is not None] + def add_xyz_to_dataset(dataset: AtomsDataset) -> AtomsDataset: """Adds the xyz data variable to a dataset with associated geometry. The new xyz data variable contains the coordinates of the atoms. - + Parameters ----------- dataset: AtomsDataset the dataset to be augmented with xyz data. """ - geometry = dataset.attrs['geometry'] + geometry = dataset.attrs["geometry"] - xyz_ds = Dataset({"xyz": (("atom", "axis"), geometry.xyz)}, coords={"axis": [0,1,2]}, attrs={"geometry": geometry}) + xyz_ds = Dataset( + {"xyz": (("atom", "axis"), geometry.xyz)}, + coords={"axis": [0, 1, 2]}, + attrs={"geometry": geometry}, + ) return xyz_ds.merge(dataset, combine_attrs="no_conflicts") + class BondsStyleSpec(TypedDict): color: Any width: Any opacity: Any -def style_bonds(bonds_data: BondsDataset, bonds_style: BondsStyleSpec, scale: float = 1.) -> BondsDataset: + +def style_bonds( + bonds_data: BondsDataset, bonds_style: BondsStyleSpec, scale: float = 1.0 +) -> BondsDataset: """Adds styles to a bonds dataset. Parameters @@ -367,15 +409,15 @@ def style_bonds(bonds_data: BondsDataset, bonds_style: BondsStyleSpec, scale: fl "width": 1, "opacity": 1, }, - bonds_style + bonds_style, ] # Initialize the styles. # Potentially bond styles could have two styles, one for each halve. parsed_bonds_style = { - "color": np.empty((nbonds, ), dtype=object), - "width": np.empty((nbonds, ), dtype=float), - "opacity": np.empty((nbonds, ), dtype=float), + "color": np.empty((nbonds,), dtype=object), + "width": np.empty((nbonds,), dtype=float), + "opacity": np.empty((nbonds,), dtype=float), } # Go specification by specification and apply the styles @@ -391,19 +433,18 @@ def style_bonds(bonds_data: BondsDataset, bonds_style: BondsStyleSpec, scale: fl parsed_bonds_style[key][:] = style - # Apply the scale - parsed_bonds_style['width'] = parsed_bonds_style['width'] * scale + parsed_bonds_style["width"] = parsed_bonds_style["width"] * scale # Convert colors to float datatype if possible try: - parsed_bonds_style['color'] = parsed_bonds_style['color'].astype(float) + parsed_bonds_style["color"] = parsed_bonds_style["color"].astype(float) except ValueError: pass # Add coordinates to the values according to their unique dimensionality. data_vars = {} for k, value in parsed_bonds_style.items(): - if (k != "color" or value.dtype not in (float, int)): + if k != "color" or value.dtype not in (float, int): unique = np.unique(value) if len(unique) == 1: data_vars[k] = unique[0] @@ -413,15 +454,16 @@ def style_bonds(bonds_data: BondsDataset, bonds_style: BondsStyleSpec, scale: fl return bonds_data.assign(data_vars) + def add_xyz_to_bonds_dataset(bonds_data: BondsDataset) -> BondsDataset: """Adds the coordinates of the bonds endpoints to a bonds dataset. - + Parameters ----------- bonds_data: BondsDataset the bonds dataset to be augmented with xyz data. """ - geometry = bonds_data.attrs['geometry'] + geometry = bonds_data.attrs["geometry"] def _bonds_xyz(ds): bonds_shape = ds.bonds.shape @@ -430,9 +472,15 @@ def _bonds_xyz(ds): return bonds_data.assign({"xyz": _bonds_xyz}) -def sanitize_bonds_selection(bonds_data: BondsDataset, atoms: Optional[npt.NDArray[np.int32]] = None, bind_bonds_to_ats: bool = False, show_bonds: bool = True) -> Union[np.ndarray, None]: + +def sanitize_bonds_selection( + bonds_data: BondsDataset, + atoms: Optional[npt.NDArray[np.int32]] = None, + bind_bonds_to_ats: bool = False, + show_bonds: bool = True, +) -> Union[np.ndarray, None]: """Sanitizes bonds selection, unifying multiple parameters into a single value - + Parameters ----------- bonds_data: BondsDataset @@ -440,12 +488,12 @@ def sanitize_bonds_selection(bonds_data: BondsDataset, atoms: Optional[npt.NDArr atoms: np.ndarray of shape (natoms,) the atoms for which we want to keep the bonds. bind_bonds_to_ats: bool - if True, the bonds will be bound to the atoms, - so that if an atom is not displayed, its bonds + if True, the bonds will be bound to the atoms, + so that if an atom is not displayed, its bonds will not be displayed either. show_bonds: bool if False, no bonds will be displayed. - """ + """ if not show_bonds: return np.array([], dtype=np.int64) elif bind_bonds_to_ats and atoms is not None: @@ -453,12 +501,13 @@ def sanitize_bonds_selection(bonds_data: BondsDataset, atoms: Optional[npt.NDArr else: return None + def bonds_to_lines(bonds_data: BondsDataset, points_per_bond: int = 2) -> BondsDataset: """Computes intermediate points between the endpoints of the bonds by interpolation. - + Bonds are concatenated into a single dimension "point index", and NaNs are added between bonds. - + Parameters ----------- bonds_data: BondsDataset @@ -470,16 +519,19 @@ def bonds_to_lines(bonds_data: BondsDataset, points_per_bond: int = 2) -> BondsD if points_per_bond > 2: bonds_data = bonds_data.interp(bond_atom=np.linspace(0, 1, points_per_bond)) - bonds_data = bonds_data.reindex({"bond_atom": [*bonds_data.bond_atom.values, 2]}).stack(point_index=bonds_data.xyz.dims[:-1]) - + bonds_data = bonds_data.reindex( + {"bond_atom": [*bonds_data.bond_atom.values, 2]} + ).stack(point_index=bonds_data.xyz.dims[:-1]) + return bonds_data + def sites_obj_to_geometry(sites_obj: BrillouinZone): """Converts anything that contains sites into a geometry. - + Possible conversions: - BrillouinZone object to geometry, kpoints to atoms. - + Parameters ----------- sites_obj @@ -489,13 +541,14 @@ def sites_obj_to_geometry(sites_obj: BrillouinZone): if isinstance(sites_obj, BrillouinZone): return Geometry(sites_obj.k.dot(sites_obj.rcell), lattice=sites_obj.rcell) else: - raise ValueError(f"Cannot convert {sites_obj.__class__.__name__} to a geometry.") - + raise ValueError( + f"Cannot convert {sites_obj.__class__.__name__} to a geometry." + ) + + def get_sites_units(sites_obj: BrillouinZone): """Units of space for an object that is to be converted into a geometry""" if isinstance(sites_obj, BrillouinZone): return "1/Ang" else: return "" - - diff --git a/src/sisl/viz/processors/grid.py b/src/sisl/viz/processors/grid.py index 986cef8054..11e46f0ae0 100644 --- a/src/sisl/viz/processors/grid.py +++ b/src/sisl/viz/processors/grid.py @@ -14,10 +14,14 @@ from .cell import infer_cell_axes, is_1D_cartesian, is_cartesian_unordered -#from ...types import Axis, PathLike -#from ..data_sources import DataSource +# from ...types import Axis, PathLike +# from ..data_sources import DataSource -def get_grid_representation(grid: Grid, represent: Literal['real', 'imag', 'mod', 'phase', 'rad_phase', 'deg_phase']) -> Grid: + +def get_grid_representation( + grid: Grid, + represent: Literal["real", "imag", "mod", "phase", "rad_phase", "deg_phase"], +) -> Grid: """Returns a representation of the grid Parameters @@ -31,29 +35,43 @@ def get_grid_representation(grid: Grid, represent: Literal['real', 'imag', 'mod' ------------ sisl.Grid """ - def _func(values: npt.NDArray[Union[np.int_, np.float_, np.complex_]]) -> npt.NDArray: - if represent == 'real': + + def _func( + values: npt.NDArray[Union[np.int_, np.float_, np.complex_]] + ) -> npt.NDArray: + if represent == "real": new_values = values.real - elif represent == 'imag': + elif represent == "imag": new_values = values.imag - elif represent == 'mod': + elif represent == "mod": new_values = np.absolute(values) - elif represent in ['phase', 'rad_phase', 'deg_phase']: + elif represent in ["phase", "rad_phase", "deg_phase"]: new_values = np.angle(values, deg=represent.startswith("deg")) else: - raise ValueError(f"'{represent}' is not a valid value for the `represent` argument") + raise ValueError( + f"'{represent}' is not a valid value for the `represent` argument" + ) return new_values return grid.apply(_func) + def tile_grid(grid: Grid, nsc: Tuple[int, int, int] = (1, 1, 1)) -> Grid: """Tiles the grid""" for ax, reps in enumerate(nsc): grid = grid.tile(reps, ax) return grid -def transform_grid_cell(grid: Grid, cell: npt.NDArray[np.float_] = np.eye(3), output_shape: Optional[Tuple[int, int, int]] = None, mode: str = "constant", order: int = 1, **kwargs) -> Grid: + +def transform_grid_cell( + grid: Grid, + cell: npt.NDArray[np.float_] = np.eye(3), + output_shape: Optional[Tuple[int, int, int]] = None, + mode: str = "constant", + order: int = 1, + **kwargs, +) -> Grid: """Applies a linear transformation to the grid to get it relative to an arbitrary cell. This method can be used, for example to get the values of the grid with respect to @@ -70,7 +88,7 @@ def transform_grid_cell(grid: Grid, cell: npt.NDArray[np.float_] = np.eye(3), ou the minimum bounding box necessary to accomodate the unit cell. output_shape: array-like of int of shape (3,), optional the shape of the final output. If not provided, the current shape of the grid - will be used. + will be used. Notice however that if the transformation applies a big shear to the image (grid) you will probably need to have a bigger output_shape. @@ -105,7 +123,7 @@ def transform_grid_cell(grid: Grid, cell: npt.NDArray[np.float_] = np.eye(3), ou # Create the transformation matrix. Since we want to control the shape # of the output, we can not use grid.dcell directly, we need to modify it. scales = output_shape / lengths - forward_t = (grid.dcell.dot(inv_cell)*scales).T + forward_t = (grid.dcell.dot(inv_cell) * scales).T # Scipy's affine transform asks for the inverse transformation matrix, to # map from output pixels to input pixels. By taking the inverse of our @@ -123,12 +141,19 @@ def transform_grid_cell(grid: Grid, cell: npt.NDArray[np.float_] = np.eye(3), ou offset = center_input - tr.dot(center_output) # We pass all the parameters to scipy's affine_transform - transformed_image = affine_transform(grid.grid, tr, order=1, offset=offset, - output_shape=output_shape, mode=mode, **kwargs) + transformed_image = affine_transform( + grid.grid, + tr, + order=1, + offset=offset, + output_shape=output_shape, + mode=mode, + **kwargs, + ) # Create a new grid with the new shape and the new cell (notice how the cell # is rescaled from the input cell to fit the actual coordinates of the system) - new_grid = grid.__class__((1, 1, 1), lattice=cell*lengths.reshape(3, 1)) + new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1)) new_grid.grid = transformed_image new_grid.geometry = grid.geometry new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset)) @@ -136,9 +161,15 @@ def transform_grid_cell(grid: Grid, cell: npt.NDArray[np.float_] = np.eye(3), ou # Find the offset between the origin before and after the transformation return new_grid -def orthogonalize_grid(grid:Grid, interp: Tuple[int, int, int] = (1, 1, 1), mode: str = "constant", **kwargs) -> Grid: + +def orthogonalize_grid( + grid: Grid, + interp: Tuple[int, int, int] = (1, 1, 1), + mode: str = "constant", + **kwargs, +) -> Grid: """Transform grid cell to be orthogonal. - + Uses `transform_grid_cell`. Parameters @@ -149,19 +180,30 @@ def orthogonalize_grid(grid:Grid, interp: Tuple[int, int, int] = (1, 1, 1), mode Number of times that the grid should be augmented for each lattice vector. mode: str, optional - determines how to handle borders. + determines how to handle borders. See `transform_grid_cell` for more info on the possible values. **kwargs: the rest of keyword arguments are passed directly to `transform_grid_cell` """ return transform_grid_cell( - grid, mode=mode, output_shape=tuple(interp[i] * grid.shape[i] for i in range(3)), cval=np.nan, **kwargs + grid, + mode=mode, + output_shape=tuple(interp[i] * grid.shape[i] for i in range(3)), + cval=np.nan, + **kwargs, ) -def orthogonalize_grid_if_needed(grid: Grid, axes: Sequence[str], tol: float = 1e-3, - interp: Tuple[int, int, int] = (1, 1, 1), mode: str = "constant", **kwargs) -> Grid: + +def orthogonalize_grid_if_needed( + grid: Grid, + axes: Sequence[str], + tol: float = 1e-3, + interp: Tuple[int, int, int] = (1, 1, 1), + mode: str = "constant", + **kwargs, +) -> Grid: """Same as `orthogonalize_grid`, but first checks if it is really needed. - + Parameters ----------- grid: sisl.Grid @@ -180,16 +222,19 @@ def orthogonalize_grid_if_needed(grid: Grid, axes: Sequence[str], tol: float = 1 the rest of keyword arguments are passed directly to `transform_grid_cell` """ - should_ortogonalize = should_transform_grid_cell_plotting(grid=grid, axes=axes, tol=tol) + should_ortogonalize = should_transform_grid_cell_plotting( + grid=grid, axes=axes, tol=tol + ) if should_ortogonalize: grid = orthogonalize_grid(grid, interp=interp, mode=mode, **kwargs) return grid + def apply_transform(grid: Grid, transform: Union[Callable, str]) -> Grid: """applies a transformation to the grid. - + Parameters ----------- grid: sisl.Grid @@ -209,9 +254,10 @@ def apply_transform(grid: Grid, transform: Union[Callable, str]) -> Grid: return grid.apply(transform) + def apply_transforms(grid: Grid, transforms: Sequence[Union[Callable, str]]) -> Grid: """Applies multiple transformations sequentially - + Parameters ----------- grid: sisl.Grid @@ -225,9 +271,12 @@ def apply_transforms(grid: Grid, transforms: Sequence[Union[Callable, str]]) -> grid = apply_transform(grid, transform) return grid -def reduce_grid(grid: Grid, reduce_method: Literal["average", "sum"], keep_axes: Sequence[int]) -> Grid: + +def reduce_grid( + grid: Grid, reduce_method: Literal["average", "sum"], keep_axes: Sequence[int] +) -> Grid: """Reduces the grid along multiple axes - + Parameters ----------- grid: sisl.Grid @@ -245,15 +294,16 @@ def reduce_grid(grid: Grid, reduce_method: Literal["average", "sum"], keep_axes: grid = getattr(grid, reduce_method)(ax) grid.origin[:] = old_origin - + return grid + def sub_grid( - grid: Grid, - x_range: Optional[Tuple[float, float]] = None, - y_range: Optional[Tuple[float, float]] = None, + grid: Grid, + x_range: Optional[Tuple[float, float]] = None, + y_range: Optional[Tuple[float, float]] = None, z_range: Optional[Tuple[float, float]] = None, - cart_tol: float = 1e-3 + cart_tol: float = 1e-3, ) -> Grid: """Returns only the part of the grid that is within the specified ranges. @@ -282,16 +332,17 @@ def sub_grid( cell = grid.lattice.cell origin = grid.origin.copy() - + # Get only the part of the grid that we need ax_ranges = [x_range, y_range, z_range] directions = ["x", "y", "z"] for ax, (ax_range, direction) in enumerate(zip(ax_ranges, directions)): if ax_range is not None: - # Cartesian check if not is_1D_cartesian(cell, direction, tol=cart_tol): - raise ValueError(f"Cannot sub grid along '{direction}', since there is no unique lattice vector that represents this direction. Cell: {cell}") + raise ValueError( + f"Cannot sub grid along '{direction}', since there is no unique lattice vector that represents this direction. Cell: {cell}" + ) # Find out which lattice vector represents the direction lattice_ax = np.where(cell[:, ax] > cart_tol)[0][0] @@ -300,7 +351,9 @@ def sub_grid( lims = np.zeros((2, 3)) # If the cell was transformed, then we need to modify # the range to get what the user wants. - lims[:, ax] = ax_range #+ self.offsets["cell_transform"][ax] - self.offsets["origin"][ax] + lims[ + :, ax + ] = ax_range # + self.offsets["cell_transform"][ax] - self.offsets["origin"][ax] origin[ax] += ax_range[0] @@ -308,16 +361,21 @@ def sub_grid( indices = np.array([grid.index(lim) for lim in lims], dtype=int) # And finally get the subpart of the grid - grid = grid.sub(np.arange(indices[0, lattice_ax], indices[1, lattice_ax] + 1), lattice_ax) + grid = grid.sub( + np.arange(indices[0, lattice_ax], indices[1, lattice_ax] + 1), + lattice_ax, + ) - grid.origin[:] = origin - + return grid -def interpolate_grid(grid: Grid, interp: Tuple[int, int, int] = (1, 1, 1), force: bool = False) -> Grid: + +def interpolate_grid( + grid: Grid, interp: Tuple[int, int, int] = (1, 1, 1), force: bool = False +) -> Grid: """Interpolates the grid. - + It also makes sure that the grid is not interpolated over dimensions that only contain one value, unless `force` is True. @@ -334,7 +392,7 @@ def interpolate_grid(grid: Grid, interp: Tuple[int, int, int] = (1, 1, 1), force Whether to force the interpolation over dimensions that only contain one value. """ - + grid_shape = np.array(grid.shape) interp_factors = np.array(interp) @@ -348,29 +406,35 @@ def interpolate_grid(grid: Grid, interp: Tuple[int, int, int] = (1, 1, 1), force return grid -def grid_geometry(grid: Grid, geometry: Optional[Geometry] = None) -> Union[Geometry, None]: + +def grid_geometry( + grid: Grid, geometry: Optional[Geometry] = None +) -> Union[Geometry, None]: """Returns the geometry associated with the grid. - + Parameters ----------- grid: sisl.Grid The grid for which we want to get the geometry. geometry: sisl.Geometry, optional If provided, this geometry will be returned instead of the one - associated with the grid. + associated with the grid. """ if geometry is None: geometry = getattr(grid, "geometry", None) - + return geometry -def should_transform_grid_cell_plotting(grid: Grid, axes: Sequence[str], tol: float = 1e-3) -> bool: + +def should_transform_grid_cell_plotting( + grid: Grid, axes: Sequence[str], tol: float = 1e-3 +) -> bool: """Determines whether the grid should be transformed for plotting. It takes into account the axes that will be plotted and checks if the grid is skewed in any of those directions. If it is, it will return True, meaning that the grid should be transformed before plotting. - + Parameters ----------- grid: sisl.Grid @@ -387,7 +451,9 @@ def should_transform_grid_cell_plotting(grid: Grid, axes: Sequence[str], tol: fl should_orthogonalize = not is_cartesian_unordered(grid, tol=tol) and len(axes) < 3 # We also don't need to orthogonalize if cartesian coordinates are not requested # (this would mean that axes is a combination of "a", "b" and "c") - should_orthogonalize = should_orthogonalize and bool(set(axes).intersection(["x", "y", "z"])) + should_orthogonalize = should_orthogonalize and bool( + set(axes).intersection(["x", "y", "z"]) + ) if should_orthogonalize and ndim == 1: # In 1D representations, even if the cell is skewed, we might not need to transform. @@ -397,12 +463,13 @@ def should_transform_grid_cell_plotting(grid: Grid, axes: Sequence[str], tol: fl # first two axes, as they don't contribute in the Z direction. Also, it is required that # "c" doesn't contribute to any of the other two directions. should_orthogonalize &= not is_1D_cartesian(grid, axes[0], tol=tol) - + return should_orthogonalize + def get_grid_axes(grid: Grid, axes: Sequence[str]) -> List[int]: """Returns the indices of the lattice vectors that correspond to the axes. - + If axes is of length 3 (i.e. a 3D view), this function always returns [0, 1, 2] regardless of what the axes are. @@ -425,11 +492,16 @@ def get_grid_axes(grid: Grid, axes: Sequence[str]) -> List[int]: return grid_axes -def get_ax_vals(grid: Grid, ax: Literal[0,1,2,"a","b","c","x","y","z"], nsc: Tuple[int, int, int]) -> npt.NDArray[np.float_]: + +def get_ax_vals( + grid: Grid, + ax: Literal[0, 1, 2, "a", "b", "c", "x", "y", "z"], + nsc: Tuple[int, int, int], +) -> npt.NDArray[np.float_]: """Returns the values of a given axis on all grid points. These can be used for example as axes ticks on a plot. - + Parameters ---------- grid: sisl.Grid @@ -446,16 +518,19 @@ def get_ax_vals(grid: Grid, ax: Literal[0,1,2,"a","b","c","x","y","z"], nsc: Tup else: ax_index = {"x": 0, "y": 1, "z": 2}[ax] - ax_vals = np.arange(0, grid.cell[ax_index, ax_index], grid.dcell[ax_index, ax_index]) + get_offset(grid, ax) + ax_vals = np.arange( + 0, grid.cell[ax_index, ax_index], grid.dcell[ax_index, ax_index] + ) + get_offset(grid, ax) if len(ax_vals) == grid.shape[ax_index] + 1: ax_vals = ax_vals[:-1] return ax_vals -def get_offset(grid: Grid, ax: Literal[0,1,2,"a","b","c","x","y","z"]) -> float: + +def get_offset(grid: Grid, ax: Literal[0, 1, 2, "a", "b", "c", "x", "y", "z"]) -> float: """Returns the offset of the grid along a certain axis. - + Parameters ----------- grid: sisl.Grid @@ -463,17 +538,20 @@ def get_offset(grid: Grid, ax: Literal[0,1,2,"a","b","c","x","y","z"]) -> float: ax: {"x", "y", "z", "a", "b", "c", 0, 1, 2} The axis for which we want the offset. """ - + if isinstance(ax, int) or ax in ("a", "b", "c"): return 0 else: coord_index = "xyz".index(ax) return grid.origin[coord_index] + GridDataArray = DataArray -def grid_to_dataarray(grid: Grid, axes: Sequence[str], grid_axes: Sequence[int], nsc: Tuple[int, int, int]) -> GridDataArray: +def grid_to_dataarray( + grid: Grid, axes: Sequence[str], grid_axes: Sequence[int], nsc: Tuple[int, int, int] +) -> GridDataArray: transpose_grid_axes = [*grid_axes] for ax in (0, 1, 2): if ax not in transpose_grid_axes: @@ -484,18 +562,18 @@ def grid_to_dataarray(grid: Grid, axes: Sequence[str], grid_axes: Sequence[int], arr = DataArray( values, coords=[ - (k, get_ax_vals(grid, ax, nsc=nsc)) - for k, ax in zip(["x", "y", "z"], axes) - ] + (k, get_ax_vals(grid, ax, nsc=nsc)) for k, ax in zip(["x", "y", "z"], axes) + ], ) - arr.attrs['grid'] = grid + arr.attrs["grid"] = grid return arr + def get_isos(data: GridDataArray, isos: Sequence[dict]) -> List[dict]: """Gets the iso surfaces or isocontours of an array of data. - + Parameters ----------- data: DataArray @@ -505,17 +583,17 @@ def get_isos(data: GridDataArray, isos: Sequence[dict]) -> List[dict]: """ from skimage.measure import find_contours - #values = data['values'].values + # values = data['values'].values values = data.values isos_to_draw = [] - + # Get the dimensionality of the data ndim = values.ndim - + if len(isos) > 0 or ndim == 3: minval = np.nanmin(values) maxval = np.nanmax(values) - + # Prepare things for each possible dimensionality if ndim == 1: # For now, we don't calculate 1D "isopoints" @@ -524,14 +602,14 @@ def get_isos(data: GridDataArray, isos: Sequence[dict]) -> List[dict]: # Get the partition size dx = data.x[1] - data.x[0] dy = data.y[1] - data.y[0] - + # Function to get the coordinates from indices def _indices_to_2Dspace(contour_coords): return contour_coords.dot([[dx, 0, 0], [0, dy, 0]]) - + def _calc_iso(isoval): contours = find_contours(values, isoval) - + contour_xs = [] contour_ys = [] for contour in contours: @@ -545,48 +623,56 @@ def _calc_iso(isoval): # Add the information about this isoline to the list of isolines return { - "x": contour_xs, "y": contour_ys, "width": iso.get("width"), + "x": contour_xs, + "y": contour_ys, + "width": iso.get("width"), } - + elif ndim == 3: # In 3D, use default isosurfaces if none were provided. if len(isos) == 0 and maxval != minval: - default_iso_frac = 0.3 #isos_param["frac"].default + default_iso_frac = 0.3 # isos_param["frac"].default # If the default frac is 0.3, they will be displayed at 0.3 and 0.7 - isos = [ - {"frac": default_iso_frac}, - {"frac": 1-default_iso_frac} - ] - + isos = [{"frac": default_iso_frac}, {"frac": 1 - default_iso_frac}] + # Define the function that will calculate each isosurface def _calc_iso(isoval): - vertices, faces, normals, intensities = data.grid.isosurface(isoval, iso.get("step_size", 1)) + vertices, faces, normals, intensities = data.grid.isosurface( + isoval, iso.get("step_size", 1) + ) + + # vertices = vertices + self._get_offsets(grid) + self.offsets["origin"] - #vertices = vertices + self._get_offsets(grid) + self.offsets["origin"] - return {"vertices": vertices, "faces": faces} + else: raise ValueError(f"Dimensionality must be lower than 3, but is {ndim}") - + # Now loop through all the isos for iso in isos: if not iso.get("active", True): continue - + # Infer the iso value either from val or from frac isoval = iso.get("val") if isoval is None: frac = iso.get("frac") if frac is None: - raise ValueError(f"You are providing an iso query without 'val' and 'frac'. There's no way to know the isovalue!\nquery: {iso}") - isoval = minval + (maxval-minval)*frac - - isos_to_draw.append({ - "color": iso.get("color"), "opacity": iso.get("opacity"), - "name": iso.get("name", "Iso: $isoval$").replace("$isoval$", f"{isoval:.4f}"), - **_calc_iso(isoval), - }) + raise ValueError( + f"You are providing an iso query without 'val' and 'frac'. There's no way to know the isovalue!\nquery: {iso}" + ) + isoval = minval + (maxval - minval) * frac + + isos_to_draw.append( + { + "color": iso.get("color"), + "opacity": iso.get("opacity"), + "name": iso.get("name", "Iso: $isoval$").replace( + "$isoval$", f"{isoval:.4f}" + ), + **_calc_iso(isoval), + } + ) return isos_to_draw - diff --git a/src/sisl/viz/processors/logic.py b/src/sisl/viz/processors/logic.py index 8253b4cd77..5c3dcfd08e 100644 --- a/src/sisl/viz/processors/logic.py +++ b/src/sisl/viz/processors/logic.py @@ -3,6 +3,7 @@ T1 = TypeVar("T1") T2 = TypeVar("T2") + def swap(val: Union[T1, T2], vals: Tuple[T1, T2]) -> Union[T1, T2]: """Given two values, returns the one that is not the input value.""" if val == vals[0]: @@ -12,10 +13,14 @@ def swap(val: Union[T1, T2], vals: Tuple[T1, T2]) -> Union[T1, T2]: else: raise ValueError(f"Value {val} not in {vals}") -def matches(first: Any, second: Any, ret_true: T1 = True, ret_false: T2 = False) -> Union[T1, T2]: + +def matches( + first: Any, second: Any, ret_true: T1 = True, ret_false: T2 = False +) -> Union[T1, T2]: """If first matches second, return ret_true, else return ret_false.""" return ret_true if first == second else ret_false - + + def switch(obj: Any, ret_true: T1, ret_false: T2) -> Union[T1, T2]: """If obj is True, return ret_true, else return ret_false.""" - return ret_true if obj else ret_false \ No newline at end of file + return ret_true if obj else ret_false diff --git a/src/sisl/viz/processors/math.py b/src/sisl/viz/processors/math.py index 258eca217f..e8d8027d56 100644 --- a/src/sisl/viz/processors/math.py +++ b/src/sisl/viz/processors/math.py @@ -21,4 +21,4 @@ def normalize(data, vmin=0, vmax=1): data = np.asarray(data) data_min = np.min(data) data_max = np.max(data) - return vmin + (vmax - vmin) * (data - data_min) / (data_max - data_min) \ No newline at end of file + return vmin + (vmax - vmin) * (data - data_min) / (data_max - data_min) diff --git a/src/sisl/viz/processors/orbital.py b/src/sisl/viz/processors/orbital.py index 8edba6aa1a..8f33709c3d 100644 --- a/src/sisl/viz/processors/orbital.py +++ b/src/sisl/viz/processors/orbital.py @@ -26,10 +26,12 @@ class OrbitalGroup(TypedDict): reduce_func: Optional[Callable] spin_reduce: Optional[Callable] + class OrbitalQueriesManager: """ This class implements an input field that allows you to select orbitals by atom, species, etc... """ + _item_input_type = OrbitalStyleQuery _keys_to_cols = { @@ -44,59 +46,85 @@ class OrbitalQueriesManager: @singledispatchmethod @classmethod - def new(cls, geometry: Geometry, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}): + def new( + cls, + geometry: Geometry, + spin: Union[str, Spin] = "", + key_gens: Dict[str, Callable] = {}, + ): return cls(geometry=geometry, spin=spin or "", key_gens=key_gens) - + @new.register @classmethod - def from_geometry(cls, geometry: Geometry, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}): + def from_geometry( + cls, + geometry: Geometry, + spin: Union[str, Spin] = "", + key_gens: Dict[str, Callable] = {}, + ): return cls(geometry=geometry, spin=spin or "", key_gens=key_gens) - + @new.register @classmethod - def from_string(cls, - string: str, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {} + def from_string( + cls, + string: str, + spin: Union[str, Spin] = "", + key_gens: Dict[str, Callable] = {}, ): """Initializes an OrbitalQueriesManager from a string, assuming it is a path.""" return cls.new(Path(string), spin=spin, key_gens=key_gens) - + @new.register @classmethod - def from_path(cls, - path: Path, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {} + def from_path( + cls, path: Path, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {} ): """Initializes an OrbitalQueriesManager from a path, converting it to a sile.""" return cls.new(sisl.get_sile(path), spin=spin, key_gens=key_gens) @new.register @classmethod - def from_sile(cls, - sile: sisl.io.BaseSile, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}, + def from_sile( + cls, + sile: sisl.io.BaseSile, + spin: Union[str, Spin] = "", + key_gens: Dict[str, Callable] = {}, ): """Initializes an OrbitalQueriesManager from a sile.""" return cls.new(sile.read_geometry(), spin=spin, key_gens=key_gens) - + @new.register @classmethod - def from_xarray(cls, - array: xarray.core.common.AttrAccessMixin, spin: Optional[Union[str, Spin]] = None, key_gens: Dict[str, Callable] = {}, + def from_xarray( + cls, + array: xarray.core.common.AttrAccessMixin, + spin: Optional[Union[str, Spin]] = None, + key_gens: Dict[str, Callable] = {}, ): """Initializes an OrbitalQueriesManager from an xarray object.""" if spin is None: spin = array.attrs.get("spin", "") return cls.new(array.attrs.get("geometry"), spin=spin, key_gens=key_gens) - + @new.register @classmethod - def from_data(cls, - data: Data, spin: Optional[Union[str, Spin]] = None, key_gens: Dict[str, Callable] = {} + def from_data( + cls, + data: Data, + spin: Optional[Union[str, Spin]] = None, + key_gens: Dict[str, Callable] = {}, ): """Initializes an OrbitalQueriesManager from a sisl Data object.""" return cls.new(data._data, spin=spin, key_gens=key_gens) - def __init__(self, geometry: Optional[Geometry] = None, spin: Union[str, Spin] = "", key_gens: Dict[str, Callable] = {}): - + def __init__( + self, + geometry: Optional[Geometry] = None, + spin: Union[str, Spin] = "", + key_gens: Dict[str, Callable] = {}, + ): self.geometry = geometry self.spin = Spin(spin) @@ -144,11 +172,18 @@ def filter_df(self, df, query, key_to_cols, raise_not_active=False): if raise_not_active: if not query["active"]: - raise ValueError(f"Query {query} is not active and you are trying to use it") + raise ValueError( + f"Query {query} is not active and you are trying to use it" + ) query_str = [] for key, val in query.items(): - if key == "orbitals" and val is not None and len(val) > 0 and isinstance(val[0], int): + if ( + key == "orbitals" + and val is not None + and len(val) > 0 + and isinstance(val[0], int) + ): df = df.iloc[val] continue @@ -156,8 +191,8 @@ def filter_df(self, df, query, key_to_cols, raise_not_active=False): if key in df and val is not None: if isinstance(val, (np.ndarray, tuple)): val = np.ravel(val).tolist() - query_str.append(f'{key}=={repr(val)}') - + query_str.append(f"{key}=={repr(val)}") + if len(query_str) == 0: return df else: @@ -168,9 +203,8 @@ def _build_orb_filtering_df(self, geom): orb_props = defaultdict(list) del_key = set() - #Loop over all orbitals of the basis + # Loop over all orbitals of the basis for at, iorb in geom.iter_orbitals(): - atom = geom.atoms[at] orb = atom[iorb] @@ -212,7 +246,7 @@ def get_options(self, key, **kwargs): np.ndarray of shape (n_options, [n_keys]) all the possible options. - If only one key was provided, it is a one dimensional array. + If only one key was provided, it is a one dimensional array. Examples ----------- @@ -229,14 +263,21 @@ def get_options(self, key, **kwargs): if kwargs: if "atoms" in kwargs: kwargs["atoms"] = self.geometry._sanitize_atoms(kwargs["atoms"]) + def _repr(v): if isinstance(v, np.ndarray): v = list(v.ravel()) if isinstance(v, dict): raise Exception(str(v)) return repr(v) - query = ' & '.join([f'{self._keys_to_cols.get(k, k)}=={_repr(v)}' for k, v in kwargs.items( - ) if self._keys_to_cols.get(k, k) in df]) + + query = " & ".join( + [ + f"{self._keys_to_cols.get(k, k)}=={_repr(v)}" + for k, v in kwargs.items() + if self._keys_to_cols.get(k, k) in df + ] + ) if query: df = df.query(query) @@ -261,8 +302,7 @@ def _repr(v): # Now get the unique options from the dataframe if keys: - options = df.drop_duplicates(subset=keys)[ - keys].values.astype(object) + options = df.drop_duplicates(subset=keys)[keys].values.astype(object) else: # It might be the only key was "spin", then we are going to fake it # to get an options array that can be treated in the same way. @@ -272,7 +312,8 @@ def _repr(v): # account the position (column index) where they are expected to be returned. if spin_in_keys and len(spin_options) > 0: options = np.concatenate( - [np.insert(options, spin_key_i, spin, axis=1) for spin in spin_options]) + [np.insert(options, spin_key_i, spin, axis=1) for spin in spin_options] + ) # Squeeze the options array, just in case there is only one key # There's a special case: if there is only one option for that key, @@ -284,28 +325,31 @@ def _repr(v): return options def get_orbitals(self, query): - if "atoms" in query: query["atoms"] = self.geometry._sanitize_atoms(query["atoms"]) - filtered_df = self.filter_df( - self.orb_filtering_df, query, self._keys_to_cols - ) + filtered_df = self.filter_df(self.orb_filtering_df, query, self._keys_to_cols) return filtered_df.index.values - - def get_atoms(self, query): + def get_atoms(self, query): if "atoms" in query: query["atoms"] = self.geometry._sanitize_atoms(query["atoms"]) - filtered_df = self.filter_df( - self.orb_filtering_df, query, self._keys_to_cols - ) + filtered_df = self.filter_df(self.orb_filtering_df, query, self._keys_to_cols) - return np.unique(filtered_df['atom'].values) + return np.unique(filtered_df["atom"].values) - def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignore_constraints=False, **kwargs): + def _split_query( + self, + query, + on, + only=None, + exclude=None, + query_gen=None, + ignore_constraints=False, + **kwargs, + ): """ Splits a query into multiple queries based on one of its parameters. @@ -328,7 +372,7 @@ def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignor This may be useful, for example, to give each request a color, or a custom name. ignore_constraints: boolean or array-like, optional - determines whether constraints (imposed by the query that you want to split) + determines whether constraints (imposed by the query that you want to split) on the parameters that we want to split along should be taken into consideration. If `False`: all constraints considered. @@ -361,8 +405,12 @@ def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignor if ignore_constraints is False: ignore_constraints = () - constraints = {key: val for key, val in constraints.items() if key not in ignore_constraints and val is not None} - + constraints = { + key: val + for key, val in constraints.items() + if key not in ignore_constraints and val is not None + } + # Knowing what are our constraints (which may be none), get the available options values = self.get_options("+".join(on), **constraints) @@ -398,18 +446,20 @@ def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignor queries = [] for i, value in enumerate(values): if value not in exclude and (only is None or value in only): - # Use the name template to generate the name for this query name = base_name for key, val in zip(on, value): name = name.replace(f"${key}", str(val)) # Build the query - query = query_gen(**{ - **query, - **{key: [val] for key, val in zip(on, value)}, - "name": name, **kwargs - }) + query = query_gen( + **{ + **query, + **{key: [val] for key, val in zip(on, value)}, + "name": name, + **kwargs, + } + ) # Make sure it is a dict if is_dataclass(query): @@ -420,12 +470,13 @@ def _split_query(self, query, on, only=None, exclude=None, query_gen=None, ignor return queries - def generate_queries(self, - split: str, + def generate_queries( + self, + split: str, only: Optional[Sequence] = None, - exclude: Optional[Sequence] = None, + exclude: Optional[Sequence] = None, query_gen: Optional[Callable[[dict], dict]] = None, - **kwargs + **kwargs, ): """ Automatically generates queries based on the current options. @@ -454,7 +505,9 @@ def generate_queries(self, will split the PDOS on the different orbitals but will take only those that belong to carbon atoms. """ - return self._split_query({}, on=split, only=only, exclude=exclude, query_gen=query_gen, **kwargs) + return self._split_query( + {}, on=split, only=only, exclude=exclude, query_gen=query_gen, **kwargs + ) def sanitize_query(self, query): # Get the complete request and make sure it is a dict. @@ -462,25 +515,26 @@ def sanitize_query(self, query): if is_dataclass(query): query = asdict(query) - # Determine the reduce function from the "reduce" passed and the scale factor. + # Determine the reduce function from the "reduce" passed and the scale factor. def _reduce_func(arr, **kwargs): - reduce_ = query['reduce'] + reduce_ = query["reduce"] if isinstance(reduce_, str): reduce_ = getattr(np, reduce_) - if kwargs['axis'] == (): + if kwargs["axis"] == (): return arr return reduce_(arr, **kwargs) * query.get("scale", 1) - + # Finally, return the sanitized request, converting the request (contains "species", "n", "l", etc...) - # into a list of orbitals. + # into a list of orbitals. return { **query, "orbitals": self.get_orbitals(query), "reduce_func": _reduce_func, - **{k: gen(query) for k, gen in self.key_gens.items()} + **{k: gen(query) for k, gen in self.key_gens.items()}, } - + + def generate_orbital_queries( orb_manager: OrbitalQueriesManager, split: str, @@ -488,17 +542,29 @@ def generate_orbital_queries( exclude: Optional[Sequence] = None, query_gen: Optional[Callable[[dict], dict]] = None, ): - return orb_manager.generate_queries(split, only=only, exclude=exclude, query_gen=query_gen) + return orb_manager.generate_queries( + split, only=only, exclude=exclude, query_gen=query_gen + ) + -def reduce_orbital_data(orbital_data: Union[DataArray, Dataset], groups: Sequence[OrbitalGroup], geometry: Optional[Geometry] = None, - reduce_func: Callable = np.mean, spin_reduce: Optional[Callable] = None, orb_dim: str = "orb", spin_dim: str = "spin", - groups_dim: str = "group", sanitize_group: Union[Callable, OrbitalQueriesManager, None] = None, group_vars: Optional[Sequence[str]] = None, - drop_empty: bool = False, fill_empty: Any = 0. +def reduce_orbital_data( + orbital_data: Union[DataArray, Dataset], + groups: Sequence[OrbitalGroup], + geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, + spin_reduce: Optional[Callable] = None, + orb_dim: str = "orb", + spin_dim: str = "spin", + groups_dim: str = "group", + sanitize_group: Union[Callable, OrbitalQueriesManager, None] = None, + group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, + fill_empty: Any = 0.0, ) -> Union[DataArray, Dataset]: """Groups contributions of orbitals into a new dimension. Given an xarray object containing orbital information and the specification of groups of orbitals, this function - computes the total contribution for each group of orbitals. It therefore removes the orbitals dimension and + computes the total contribution for each group of orbitals. It therefore removes the orbitals dimension and creates a new one to account for the groups. It can also reduce spin in the same go if requested. In that case, groups can also specify particular spin components. @@ -516,7 +582,7 @@ def reduce_orbital_data(orbital_data: Union[DataArray, Dataset], groups: Sequenc afterwards in the ``parent`` attribute, under ``parent.geometry``. reduce_func : Callable, optional The function that will compute the reduction along the orbitals dimension once the selection is done. - This could be for example ``numpy.mean`` or ``numpy.sum``. + This could be for example ``numpy.mean`` or ``numpy.sum``. Notice that this will only be used in case the group specification doesn't specify a particular function in its "reduce_func" field, which will take preference. spin_reduce: Callable, optional @@ -548,13 +614,15 @@ def reduce_orbital_data(orbital_data: Union[DataArray, Dataset], groups: Sequenc if geometry is None: geometry = orbital_data.attrs.get("geometry") if geometry is None: - parent = orbital_data.attrs.get('parent') + parent = orbital_data.attrs.get("parent") if parent is not None: getattr(parent, "geometry") if sanitize_group is None: if geometry is not None: - sanitize_group = OrbitalQueriesManager(geometry=geometry, spin=orbital_data.attrs.get("spin", "")) + sanitize_group = OrbitalQueriesManager( + geometry=geometry, spin=orbital_data.attrs.get("spin", "") + ) else: sanitize_group = lambda x: x if isinstance(sanitize_group, OrbitalQueriesManager): @@ -567,101 +635,168 @@ def _sanitize_group(group): group = sanitize_group(group) if geometry is None: - orbitals = group.get('orbitals') + orbitals = group.get("orbitals") try: - group['orbitals'] = np.array(orbitals, dtype=int) + group["orbitals"] = np.array(orbitals, dtype=int) assert orbitals.ndim == 1 except: - raise SislError("A geometry was neither provided nor found in the xarray object. Therefore we can't" - f" convert the provided atom selection ({orbitals}) to an array of integers.") + raise SislError( + "A geometry was neither provided nor found in the xarray object. Therefore we can't" + f" convert the provided atom selection ({orbitals}) to an array of integers." + ) else: group["orbitals"] = geometry._sanitize_orbs(group["orbitals"]) - group['selector'] = group['orbitals'] + group["selector"] = group["orbitals"] req_spin = group.get("spin") - if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords: + if ( + req_spin is None + and data_spin.is_polarized + and spin_dim in orbital_data.coords + ): if spin_reduce is None: - group['spin'] = original_spin_coord + group["spin"] = original_spin_coord else: - group['spin'] = [0, 1] + group["spin"] = [0, 1] - if (spin_reduce is not None or group.get("spin") is not None) and spin_dim in orbital_data.dims: - group['selector'] = (group['selector'], group.get('spin')) - group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) + if ( + spin_reduce is not None or group.get("spin") is not None + ) and spin_dim in orbital_data.dims: + group["selector"] = (group["selector"], group.get("spin")) + group["reduce_func"] = (group.get("reduce_func", reduce_func), spin_reduce) return group - + original_spin_coord = None if data_spin.is_polarized and spin_dim in orbital_data.coords: - if not isinstance(orbital_data, (DataArray, Dataset)): orbital_data = orbital_data._data - + original_spin_coord = orbital_data.coords[spin_dim].values - if "total" in orbital_data.coords['spin']: - spin_up = ((orbital_data.sel(spin="total") - orbital_data.sel(spin="z")) / 2).assign_coords(spin=0) - spin_down = ((orbital_data.sel(spin="total") + orbital_data.sel(spin="z")) / 2).assign_coords(spin=1) + if "total" in orbital_data.coords["spin"]: + spin_up = ( + (orbital_data.sel(spin="total") - orbital_data.sel(spin="z")) / 2 + ).assign_coords(spin=0) + spin_down = ( + (orbital_data.sel(spin="total") + orbital_data.sel(spin="z")) / 2 + ).assign_coords(spin=1) orbital_data = xarray.concat([orbital_data, spin_up, spin_down], "spin") else: total = orbital_data.sum(spin_dim).assign_coords(spin="total") - z = (orbital_data.sel(spin=0) - orbital_data.sel(spin=1)).assign_coords(spin="z") + z = (orbital_data.sel(spin=0) - orbital_data.sel(spin=1)).assign_coords( + spin="z" + ) orbital_data = xarray.concat([total, z, orbital_data], "spin") - + # If a reduction for spin was requested, then pass the two different functions to reduce # each coordinate. reduce_funcs = reduce_func reduce_dims = orb_dim - if (spin_reduce is not None or data_spin.is_polarized) and spin_dim in orbital_data.dims: + if ( + spin_reduce is not None or data_spin.is_polarized + ) and spin_dim in orbital_data.dims: reduce_funcs = (reduce_func, spin_reduce) reduce_dims = (orb_dim, spin_dim) return group_reduce( - data=orbital_data, groups=groups, reduce_dim=reduce_dims, reduce_func=reduce_funcs, - groups_dim=groups_dim, sanitize_group=_sanitize_group, group_vars=group_vars, - drop_empty=drop_empty, fill_empty=fill_empty + data=orbital_data, + groups=groups, + reduce_dim=reduce_dims, + reduce_func=reduce_funcs, + groups_dim=groups_dim, + sanitize_group=_sanitize_group, + group_vars=group_vars, + drop_empty=drop_empty, + fill_empty=fill_empty, ) -def get_orbital_queries_manager(obj, spin: Optional[str] = None, key_gens: Dict[str, Callable] = {}) -> OrbitalQueriesManager: + +def get_orbital_queries_manager( + obj, spin: Optional[str] = None, key_gens: Dict[str, Callable] = {} +) -> OrbitalQueriesManager: return OrbitalQueriesManager.new(obj, spin=spin, key_gens=key_gens) -def split_orbitals(orbital_data, on="species", only=None, exclude=None, geometry: Optional[Geometry] = None, - reduce_func: Callable = np.mean, spin_reduce: Optional[Callable] = None, orb_dim: str = "orb", spin_dim: str = "spin", - groups_dim: str = "group", group_vars: Optional[Sequence[str]] = None, - drop_empty: bool = False, fill_empty: Any = 0., **kwargs): +def split_orbitals( + orbital_data, + on="species", + only=None, + exclude=None, + geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, + spin_reduce: Optional[Callable] = None, + orb_dim: str = "orb", + spin_dim: str = "spin", + groups_dim: str = "group", + group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, + fill_empty: Any = 0.0, + **kwargs, +): if geometry is not None: orbital_data = orbital_data.copy() - orbital_data.attrs['geometry'] = geometry + orbital_data.attrs["geometry"] = geometry orbital_data = orbital_data.copy() - orb_manager = get_orbital_queries_manager(orbital_data, key_gens=kwargs.pop("key_gens", {})) + orb_manager = get_orbital_queries_manager( + orbital_data, key_gens=kwargs.pop("key_gens", {}) + ) - groups = orb_manager.generate_queries(split=on, only=only, exclude=exclude, **kwargs) + groups = orb_manager.generate_queries( + split=on, only=only, exclude=exclude, **kwargs + ) return reduce_orbital_data( - orbital_data, groups=groups, sanitize_group=orb_manager, reduce_func=reduce_func, spin_reduce=spin_reduce, - orb_dim=orb_dim, spin_dim=spin_dim, groups_dim=groups_dim, group_vars=group_vars, drop_empty=drop_empty, - fill_empty=fill_empty + orbital_data, + groups=groups, + sanitize_group=orb_manager, + reduce_func=reduce_func, + spin_reduce=spin_reduce, + orb_dim=orb_dim, + spin_dim=spin_dim, + groups_dim=groups_dim, + group_vars=group_vars, + drop_empty=drop_empty, + fill_empty=fill_empty, ) -def atom_data_from_orbital_data(orbital_data, atoms: AtomsArgument = None, request_kwargs: Dict = {}, geometry: Optional[Geometry] = None, - reduce_func: Callable = np.mean, spin_reduce: Optional[Callable] = None, orb_dim: str = "orb", spin_dim: str = "spin", - groups_dim: str = "atom", group_vars: Optional[Sequence[str]] = None, - drop_empty: bool = False, fill_empty: Any = 0., + +def atom_data_from_orbital_data( + orbital_data, + atoms: AtomsArgument = None, + request_kwargs: Dict = {}, + geometry: Optional[Geometry] = None, + reduce_func: Callable = np.mean, + spin_reduce: Optional[Callable] = None, + orb_dim: str = "orb", + spin_dim: str = "spin", + groups_dim: str = "atom", + group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, + fill_empty: Any = 0.0, ): request_kwargs["name"] = "$atoms" atom_data = split_orbitals( - orbital_data, on="atoms", only=atoms, reduce_func=reduce_func, spin_reduce=spin_reduce, - orb_dim=orb_dim, spin_dim=spin_dim, groups_dim=groups_dim, group_vars=group_vars, drop_empty=drop_empty, - fill_empty=fill_empty, **request_kwargs + orbital_data, + on="atoms", + only=atoms, + reduce_func=reduce_func, + spin_reduce=spin_reduce, + orb_dim=orb_dim, + spin_dim=spin_dim, + groups_dim=groups_dim, + group_vars=group_vars, + drop_empty=drop_empty, + fill_empty=fill_empty, + **request_kwargs, ) atom_data = atom_data.assign_coords(atom=atom_data.atom.astype(int)) - return atom_data \ No newline at end of file + return atom_data diff --git a/src/sisl/viz/processors/spin.py b/src/sisl/viz/processors/spin.py index 1b151d35ac..26e4b5dcd2 100644 --- a/src/sisl/viz/processors/spin.py +++ b/src/sisl/viz/processors/spin.py @@ -4,21 +4,30 @@ _options = { Spin.UNPOLARIZED: [], - Spin.POLARIZED: [{"label": "↑", "value": 0}, {"label": "↓", "value": 1}, - {"label": "Total", "value": "total"}, {"label": "Net z", "value": "z"}], - Spin.NONCOLINEAR: [{"label": val, "value": val} for val in ("total", "x", "y", "z")], - Spin.SPINORBIT: [{"label": val, "value": val} for val in ("total", "x", "y", "z")] + Spin.POLARIZED: [ + {"label": "↑", "value": 0}, + {"label": "↓", "value": 1}, + {"label": "Total", "value": "total"}, + {"label": "Net z", "value": "z"}, + ], + Spin.NONCOLINEAR: [ + {"label": val, "value": val} for val in ("total", "x", "y", "z") + ], + Spin.SPINORBIT: [{"label": val, "value": val} for val in ("total", "x", "y", "z")], } -def get_spin_options(spin: Union[Spin, str], only_if_polarized: bool = False) -> List[Literal[0, 1, "total", "x", "y", "z"]]: + +def get_spin_options( + spin: Union[Spin, str], only_if_polarized: bool = False +) -> List[Literal[0, 1, "total", "x", "y", "z"]]: """Returns the options for a given spin class. - + Parameters ---------- spin: sisl.Spin or str The spin class to get the options for. only_if_polarized: bool, optional - If set to `True`, non colinear spins will not have multiple options. + If set to `True`, non colinear spins will not have multiple options. """ spin = Spin(spin) @@ -27,4 +36,4 @@ def get_spin_options(spin: Union[Spin, str], only_if_polarized: bool = False) -> else: options_spin = spin - return [option['value'] for option in _options[options_spin.kind]] \ No newline at end of file + return [option["value"] for option in _options[options_spin.kind]] diff --git a/src/sisl/viz/processors/tests/test_axes.py b/src/sisl/viz/processors/tests/test_axes.py index ca963683df..08c673c737 100644 --- a/src/sisl/viz/processors/tests/test_axes.py +++ b/src/sisl/viz/processors/tests/test_axes.py @@ -12,7 +12,6 @@ def test_sanitize_axes(): - assert sanitize_axes(["x", "y", "z"]) == ["x", "y", "z"] assert sanitize_axes("xyz") == ["x", "y", "z"] assert sanitize_axes("abc") == ["a", "b", "c"] @@ -22,16 +21,16 @@ def test_sanitize_axes(): assert sanitize_axes("-x-y") == ["-x", "-y"] assert sanitize_axes("a-b") == ["a", "-b"] - axes = sanitize_axes([[0,1,2]]) + axes = sanitize_axes([[0, 1, 2]]) assert isinstance(axes[0], np.ndarray) assert axes[0].shape == (3,) - assert np.all(axes[0] == [0,1,2]) + assert np.all(axes[0] == [0, 1, 2]) with pytest.raises(ValueError): sanitize_axes([None]) -def test_axis_direction(): +def test_axis_direction(): assert np.allclose(axis_direction("x"), [1, 0, 0]) assert np.allclose(axis_direction("y"), [0, 1, 0]) assert np.allclose(axis_direction("z"), [0, 0, 1]) @@ -52,14 +51,14 @@ def test_axis_direction(): assert np.allclose(axis_direction("-b", cell), [-1, 0, 0]) assert np.allclose(axis_direction("-c", cell), [0, -1, 0]) -def test_axes_cross_product(): +def test_axes_cross_product(): assert np.allclose(axes_cross_product("x", "y"), [0, 0, 1]) assert np.allclose(axes_cross_product("y", "x"), [0, 0, -1]) assert np.allclose(axes_cross_product("-x", "y"), [0, 0, -1]) - assert np.allclose(axes_cross_product([1,0,0], [0,1,0]), [0, 0, 1]) - assert np.allclose(axes_cross_product([0,1,0], [1,0,0]), [0, 0, -1]) + assert np.allclose(axes_cross_product([1, 0, 0], [0, 1, 0]), [0, 0, 1]) + assert np.allclose(axes_cross_product([0, 1, 0], [1, 0, 0]), [0, 0, -1]) cell = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) @@ -67,8 +66,8 @@ def test_axes_cross_product(): assert np.allclose(axes_cross_product("c", "b", cell), [0, 0, -1]) assert np.allclose(axes_cross_product("-b", "c", cell), [0, 0, -1]) -def test_axis_title(): +def test_axis_title(): assert get_ax_title("title") == "title" assert get_ax_title("x") == "X axis [Ang]" @@ -81,9 +80,9 @@ def test_axis_title(): assert get_ax_title(None) == "" - assert get_ax_title(np.array([1,2,3])) == "[1 2 3]" + assert get_ax_title(np.array([1, 2, 3])) == "[1 2 3]" - def some_axis(): pass + def some_axis(): + pass assert get_ax_title(some_axis) == "some_axis" - diff --git a/src/sisl/viz/processors/tests/test_bands.py b/src/sisl/viz/processors/tests/test_bands.py index b836e4abc6..06edeb48c2 100644 --- a/src/sisl/viz/processors/tests/test_bands.py +++ b/src/sisl/viz/processors/tests/test_bands.py @@ -18,24 +18,30 @@ pytestmark = [pytest.mark.viz, pytest.mark.processors] -@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"]) +@pytest.fixture( + scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"] +) def spin(request): return Spin(request.param) + @pytest.fixture(scope="module") def gap(): return 2.5 + @pytest.fixture(scope="module") def bands_data(spin, gap): return BandsData.toy_example(spin=spin, gap=gap) + @pytest.fixture(scope="module", params=["x", "y"]) def E_axis(request): return request.param + def test_filter_bands(bands_data): - spin = bands_data.attrs['spin'] + spin = bands_data.attrs["spin"] # Check that it works without any arguments filtered_bands = filter_bands(bands_data) @@ -64,35 +70,34 @@ def test_filter_bands(bands_data): def test_calculate_gap(bands_data, gap): - spin = bands_data.attrs["spin"] gap_info = calculate_gap(bands_data) # Check that the gap value is correct - assert gap_info['gap'] == gap + assert gap_info["gap"] == gap # Check also that the position of the gap is in the information - assert isinstance(gap_info['k'], tuple) and len(gap_info['k']) == 2 + assert isinstance(gap_info["k"], tuple) and len(gap_info["k"]) == 2 VB = len(bands_data.band) // 2 - 1 - assert isinstance(gap_info['bands'], tuple) and len(gap_info['bands']) == 2 - assert gap_info['bands'][0] < gap_info['bands'][1] + assert isinstance(gap_info["bands"], tuple) and len(gap_info["bands"]) == 2 + assert gap_info["bands"][0] < gap_info["bands"][1] - assert isinstance(gap_info['spin'], tuple) and len(gap_info['spin']) == 2 + assert isinstance(gap_info["spin"], tuple) and len(gap_info["spin"]) == 2 if not spin.is_polarized: - assert gap_info['spin'] == (0, 0) + assert gap_info["spin"] == (0, 0) - assert isinstance(gap_info['Es'], tuple) and len(gap_info['Es']) == 2 - assert np.allclose(gap_info['Es'], (- gap / 2, gap / 2)) + assert isinstance(gap_info["Es"], tuple) and len(gap_info["Es"]) == 2 + assert np.allclose(gap_info["Es"], (-gap / 2, gap / 2)) -def test_sanitize_k(bands_data): +def test_sanitize_k(bands_data): assert sanitize_k(bands_data, "Gamma") == 0 assert sanitize_k(bands_data, "X") == 1 -def test_get_gap_coords(bands_data): +def test_get_gap_coords(bands_data): spin = bands_data.attrs["spin"] vb = len(bands_data.band) // 2 - 1 @@ -100,8 +105,9 @@ def test_get_gap_coords(bands_data): # We can get the gamma gap by specifying both origin and destination or # just origin. Check also Gamma to X. for to_k in ["Gamma", None, "X"]: - - k, E = get_gap_coords(bands_data, (vb, vb + 1), from_k="Gamma", to_k=to_k, spin=1) + k, E = get_gap_coords( + bands_data, (vb, vb + 1), from_k="Gamma", to_k=to_k, spin=1 + ) kval = 1 if to_k == "X" else 0 @@ -118,8 +124,8 @@ def test_get_gap_coords(bands_data): assert E[0] == bands_E.sel(band=vb, k=0) assert E[1] == bands_E.sel(band=vb + 1, k=kval) -def test_draw_gap(E_axis): +def test_draw_gap(E_axis): ks = (0, 0.5) Es = (0, 1) @@ -138,21 +144,27 @@ def test_draw_gap(E_axis): assert action_kwargs["name"] == "test" assert action_kwargs["line"]["color"] == "red" assert action_kwargs["marker"]["color"] == "red" - assert action_kwargs['x'] == x - assert action_kwargs['y'] == y + assert action_kwargs["x"] == x + assert action_kwargs["y"] == y + @pytest.mark.parametrize("display_gap", [True, False]) def test_draw_gaps(bands_data, E_axis, display_gap): - spin = bands_data.attrs["spin"] gap_info = calculate_gap(bands_data) # Run the function only to draw the minimum gap. gap_actions = draw_gaps( - bands_data, gap=display_gap, gap_info=gap_info, - gap_tol=0.3, gap_color="red", gap_marker={}, - direct_gaps_only=False, custom_gaps=[], E_axis=E_axis + bands_data, + gap=display_gap, + gap_info=gap_info, + gap_tol=0.3, + gap_color="red", + gap_marker={}, + direct_gaps_only=False, + custom_gaps=[], + E_axis=E_axis, ) assert isinstance(gap_actions, list) @@ -168,15 +180,21 @@ def test_draw_gaps(bands_data, E_axis, display_gap): # Now run the function with a custom gap. gap_actions = draw_gaps( - bands_data, gap=display_gap, gap_info=gap_info, - gap_tol=0.3, gap_color="red", gap_marker={}, - direct_gaps_only=False, - custom_gaps=[{"from": "Gamma", "to": "X", "color": "blue"}], - E_axis=E_axis + bands_data, + gap=display_gap, + gap_info=gap_info, + gap_tol=0.3, + gap_color="red", + gap_marker={}, + direct_gaps_only=False, + custom_gaps=[{"from": "Gamma", "to": "X", "color": "blue"}], + E_axis=E_axis, ) assert isinstance(gap_actions, list) - assert len(gap_actions) == (2 if display_gap else 1) + (1 if spin.is_polarized else 0) + assert len(gap_actions) == (2 if display_gap else 1) + ( + 1 if spin.is_polarized else 0 + ) # Check the minimum gap if display_gap: @@ -186,7 +204,7 @@ def test_draw_gaps(bands_data, E_axis, display_gap): action_kwargs = gap_actions[0]["kwargs"] assert action_kwargs["line"]["color"] == "red" assert action_kwargs["marker"]["color"] == "red" - + # Check the custom gap assert isinstance(gap_actions[-1], dict) assert gap_actions[-1]["method"] == "draw_line" @@ -194,16 +212,17 @@ def test_draw_gaps(bands_data, E_axis, display_gap): action_kwargs = gap_actions[-1]["kwargs"] assert action_kwargs["line"]["color"] == "blue" assert action_kwargs["marker"]["color"] == "blue" - assert action_kwargs['x' if E_axis == "y" else "y"] == (0, 1) + assert action_kwargs["x" if E_axis == "y" else "y"] == (0, 1) -def test_style_bands(bands_data): +def test_style_bands(bands_data): spin = bands_data.attrs["spin"] # Check basic styles styled_bands = style_bands( - bands_data, {"color": "red", "width": 3}, - spindown_style={"opacity": 0.5, "color": "blue"} + bands_data, + {"color": "red", "width": 3}, + spindown_style={"opacity": 0.5, "color": "blue"}, ) assert isinstance(styled_bands, xr.Dataset) @@ -223,20 +242,20 @@ def test_style_bands(bands_data): # Check function as style def color(data): return xr.DataArray( - np.where(data.band < 5, "red", "blue"), - coords=[("band", data.band.values)] + np.where(data.band < 5, "red", "blue"), coords=[("band", data.band.values)] ) - + styled_bands = style_bands( - bands_data, {"color": color, "width": 3}, - spindown_style={"opacity": 0.5, "color": "blue"} + bands_data, + {"color": color, "width": 3}, + spindown_style={"opacity": 0.5, "color": "blue"}, ) assert isinstance(styled_bands, xr.Dataset) - + for k in ("color", "width", "opacity"): assert k in styled_bands.data_vars - + assert "band" in styled_bands.color.coords if spin.is_polarized: bands_color = styled_bands.color.sel(spin=0) @@ -244,4 +263,3 @@ def color(data): else: bands_color = styled_bands.color assert np.all((styled_bands.band < 5) == (bands_color == "red")) - diff --git a/src/sisl/viz/processors/tests/test_cell.py b/src/sisl/viz/processors/tests/test_cell.py index f54e7861e0..4e62522eda 100644 --- a/src/sisl/viz/processors/tests/test_cell.py +++ b/src/sisl/viz/processors/tests/test_cell.py @@ -16,14 +16,13 @@ @pytest.fixture(scope="module", params=["numpy", "lattice"]) def Cell(request): - if request.param == "numpy": return np.array elif request.param == "lattice": return Lattice - -def test_cartesian_unordered(Cell): + +def test_cartesian_unordered(Cell): assert is_cartesian_unordered( Cell([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), ) @@ -36,39 +35,28 @@ def test_cartesian_unordered(Cell): Cell([[0, 2, 1], [1, 0, 0], [0, 1, 0]]), ) + def test_1D_cartesian(Cell): + assert is_1D_cartesian(Cell([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), "x") - assert is_1D_cartesian( - Cell([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - "x" - ) + assert is_1D_cartesian(Cell([[1, 0, 0], [0, 1, 1], [0, 0, 1]]), "x") - assert is_1D_cartesian( - Cell([[1, 0, 0], [0, 1, 1], [0, 0, 1]]), - "x" - ) + assert not is_1D_cartesian(Cell([[1, 0, 0], [1, 1, 0], [0, 0, 1]]), "x") - assert not is_1D_cartesian( - Cell([[1, 0, 0], [1, 1, 0], [0, 0, 1]]), - "x" - ) def test_infer_cell_axes(Cell): - assert infer_cell_axes( - Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]), - axes=["x", "y", "z"] + Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]), axes=["x", "y", "z"] ) == [1, 0, 2] assert infer_cell_axes( - Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]), - axes=["b", "y"] + Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]), axes=["b", "y"] ) == [1, 0] -def test_gen_cell_dataset(): +def test_gen_cell_dataset(): lattice = Lattice([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - + cell_dataset = gen_cell_dataset(lattice) assert isinstance(cell_dataset, xr.Dataset) @@ -80,11 +68,11 @@ def test_gen_cell_dataset(): assert cell_dataset.xyz.shape == (2, 2, 2, 3) assert np.all(cell_dataset.xyz.values == lattice.vertices()) + @pytest.mark.parametrize("mode", ["box", "axes", "other"]) def test_cell_to_lines(mode): - lattice = Lattice([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - + cell_dataset = gen_cell_dataset(lattice) if mode == "other": diff --git a/src/sisl/viz/processors/tests/test_coords.py b/src/sisl/viz/processors/tests/test_coords.py index 7082d57ad9..61e3e1761c 100644 --- a/src/sisl/viz/processors/tests/test_coords.py +++ b/src/sisl/viz/processors/tests/test_coords.py @@ -20,25 +20,24 @@ @pytest.fixture(scope="module", params=["numpy", "lattice"]) def Cell(request): - if request.param == "numpy": return np.array elif request.param == "lattice": return Lattice + @pytest.fixture(scope="module") def coords_dataset(): geometry = sisl.geom.bcc(2.93, "Au", False) return xr.Dataset( - {"xyz": (("atom", "axis"), geometry.xyz)}, - coords={"axis": [0,1,2]}, - attrs={"geometry": geometry} + {"xyz": (("atom", "axis"), geometry.xyz)}, + coords={"axis": [0, 1, 2]}, + attrs={"geometry": geometry}, ) - + def test_projected_1D_coords(Cell): - cell = Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) x, y, z = 3, -4, 2 @@ -60,9 +59,8 @@ def test_projected_1D_coords(Cell): projected = projected_1Dcoords(cell, coords, [x, 0, z]) assert np.allclose(projected, [[1]]) - -def test_projected_2D_coords(Cell): +def test_projected_2D_coords(Cell): cell = Cell([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) x, y, z = 3, -4, 2 @@ -90,21 +88,21 @@ def test_projected_2D_coords(Cell): projected = projected_2Dcoords(cell, coords, [x, y, 0], [0, 0, z]) assert np.allclose(projected, [[1, 1]]) + def test_coords_depth(coords_dataset): - depth = coords_depth(coords_dataset, ["x", "y"]) assert isinstance(depth, np.ndarray) assert np.allclose(depth, coords_dataset.xyz.sel(axis=2).values) depth = coords_depth(coords_dataset, ["y", "x"]) - assert np.allclose(depth, - coords_dataset.xyz.sel(axis=2).values) + assert np.allclose(depth, -coords_dataset.xyz.sel(axis=2).values) depth = coords_depth(coords_dataset, [[1, 0, 0], [0, 0, 1]]) - assert np.allclose(depth, - coords_dataset.xyz.sel(axis=1).values) + assert np.allclose(depth, -coords_dataset.xyz.sel(axis=1).values) + @pytest.mark.parametrize("center", [[0, 0, 0], [1, 1, 0]]) def test_sphere(center): - coords = sphere(center=center, r=3.5, vertices=15) assert isinstance(coords, dict) @@ -113,14 +111,16 @@ def test_sphere(center): assert "y" in coords assert "z" in coords - assert coords["x"].shape == coords["y"].shape == coords["z"].shape == (15 ** 2,) + assert coords["x"].shape == coords["y"].shape == coords["z"].shape == (15**2,) - R = np.linalg.norm(np.array([coords["x"], coords["y"], coords["z"]]).T - center, axis=1) + R = np.linalg.norm( + np.array([coords["x"], coords["y"], coords["z"]]).T - center, axis=1 + ) assert np.allclose(R, 3.5) -def test_projected_1D_data(coords_dataset): +def test_projected_1D_data(coords_dataset): # No data projected = projected_1D_data(coords_dataset, "y") assert isinstance(projected, xr.Dataset) @@ -133,24 +133,26 @@ def test_projected_1D_data(coords_dataset): projected = projected_1D_data(coords_dataset, "-y", dataaxis_1d=np.sin) assert isinstance(projected, xr.Dataset) assert "x" in projected.data_vars - assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert np.allclose(projected.x, -coords_dataset.xyz.sel(axis=1)) assert "y" in projected.data_vars - assert np.allclose(projected.y, np.sin(- coords_dataset.xyz.sel(axis=1))) + assert np.allclose(projected.y, np.sin(-coords_dataset.xyz.sel(axis=1))) # Data from array - projected = projected_1D_data(coords_dataset, "-y", dataaxis_1d=coords_dataset.xyz.sel(axis=2).values) + projected = projected_1D_data( + coords_dataset, "-y", dataaxis_1d=coords_dataset.xyz.sel(axis=2).values + ) assert isinstance(projected, xr.Dataset) assert "x" in projected.data_vars - assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert np.allclose(projected.x, -coords_dataset.xyz.sel(axis=1)) assert "y" in projected.data_vars assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=2)) -def test_projected_2D_data(coords_dataset): +def test_projected_2D_data(coords_dataset): projected = projected_2D_data(coords_dataset, "-y", "x") assert isinstance(projected, xr.Dataset) assert "x" in projected.data_vars - assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert np.allclose(projected.x, -coords_dataset.xyz.sel(axis=1)) assert "y" in projected.data_vars assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=0)) @@ -168,19 +170,19 @@ def test_projected_2D_data(coords_dataset): # Check that points are sorted by depth. assert np.all(np.diff(projected.depth) > 0) -def test_projected_3D_data(coords_dataset): +def test_projected_3D_data(coords_dataset): projected = projected_3D_data(coords_dataset) assert isinstance(projected, xr.Dataset) assert "x" in projected.data_vars - assert np.allclose(projected.x, coords_dataset.xyz.sel(axis=0)) + assert np.allclose(projected.x, coords_dataset.xyz.sel(axis=0)) assert "y" in projected.data_vars assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=1)) assert "z" in projected.data_vars assert np.allclose(projected.z, coords_dataset.xyz.sel(axis=2)) -def test_project_to_axes(coords_dataset): +def test_project_to_axes(coords_dataset): projected = project_to_axes(coords_dataset, ["z"], dataaxis_1d=4) assert isinstance(projected, xr.Dataset) assert "x" in projected.data_vars @@ -192,7 +194,7 @@ def test_project_to_axes(coords_dataset): projected = project_to_axes(coords_dataset, ["-y", "x"]) assert isinstance(projected, xr.Dataset) assert "x" in projected.data_vars - assert np.allclose(projected.x, - coords_dataset.xyz.sel(axis=1)) + assert np.allclose(projected.x, -coords_dataset.xyz.sel(axis=1)) assert "y" in projected.data_vars assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=0)) assert "z" not in projected.data_vars @@ -205,4 +207,3 @@ def test_project_to_axes(coords_dataset): assert np.allclose(projected.y, coords_dataset.xyz.sel(axis=1)) assert "z" in projected.data_vars assert np.allclose(projected.z, coords_dataset.xyz.sel(axis=2)) - diff --git a/src/sisl/viz/processors/tests/test_data.py b/src/sisl/viz/processors/tests/test_data.py index 1b2d2dbb1a..e7276cce63 100644 --- a/src/sisl/viz/processors/tests/test_data.py +++ b/src/sisl/viz/processors/tests/test_data.py @@ -13,12 +13,13 @@ def __init__(self, valid: bool = True): def sanity_check(self): assert self._data == True + class OtherData(Data): pass + @pytest.mark.parametrize("valid", [True, False]) def test_accept_data(valid): - data = FakeData(valid) # If the input is an instance of an invalid class @@ -35,9 +36,9 @@ def test_accept_data(valid): # Don't perform a sanity check on data assert accept_data(data, FakeData, check=False) is data + @pytest.mark.parametrize("valid", [True, False]) def test_extract_data(valid): - data = FakeData(valid) # If the input is an instance of an invalid class @@ -53,4 +54,3 @@ def test_extract_data(valid): # Don't perform a sanity check on data assert extract_data(data, FakeData, check=False) is data._data - diff --git a/src/sisl/viz/processors/tests/test_eigenstate.py b/src/sisl/viz/processors/tests/test_eigenstate.py index 4184100232..7625589b82 100644 --- a/src/sisl/viz/processors/tests/test_eigenstate.py +++ b/src/sisl/viz/processors/tests/test_eigenstate.py @@ -21,15 +21,16 @@ def k(request): elif request.param == "X": return (0.5, 0, 0) + @pytest.fixture(scope="module") def graphene(): - r = np.linspace(0, 3.5, 50) f = np.exp(-r) - orb = sisl.AtomicOrbital('2pzZ', (r, f)) + orb = sisl.AtomicOrbital("2pzZ", (r, f)) return sisl.geom.graphene(orthogonal=True, atoms=sisl.Atom(6, orb)) + @pytest.fixture(scope="module") def eigenstate(k, graphene): # Create a simple graphene tight binding Hamiltonian @@ -38,8 +39,8 @@ def eigenstate(k, graphene): return H.eigenstate(k=k) -def test_get_eigenstate(eigenstate, graphene): +def test_get_eigenstate(eigenstate, graphene): sel_eigenstate = get_eigenstate(eigenstate, 2) assert sel_eigenstate.state.shape == (1, graphene.no) @@ -54,8 +55,8 @@ def test_get_eigenstate(eigenstate, graphene): assert sel_eigenstate.state.shape == (1, graphene.no) assert np.allclose(sel_eigenstate.state, eigenstate.state[3]) -def test_eigenstate_geometry(eigenstate, graphene): +def test_eigenstate_geometry(eigenstate, graphene): # It should give us the geometry associated with the eigenstate assert eigenstate_geometry(eigenstate) is graphene @@ -63,12 +64,12 @@ def test_eigenstate_geometry(eigenstate, graphene): graphene_copy = graphene.copy() assert eigenstate_geometry(eigenstate, graphene_copy) is graphene_copy -def test_tile_if_k(eigenstate, graphene): +def test_tile_if_k(eigenstate, graphene): # If the eigenstate is calculated at gamma, we don't need to tile tiled_geometry = tile_if_k(graphene, (2, 2, 2), eigenstate) - if eigenstate.info["k"] == (0,0,0): + if eigenstate.info["k"] == (0, 0, 0): # If the eigenstate is calculated at gamma, we don't need to tile assert tiled_geometry is graphene elif eigenstate.info["k"] == (0.5, 0, 0): @@ -77,17 +78,17 @@ def test_tile_if_k(eigenstate, graphene): assert tiled_geometry is not graphene assert np.allclose(tiled_geometry.cell, graphene.cell * (2, 1, 1)) -def test_get_grid_nsc(eigenstate): +def test_get_grid_nsc(eigenstate): grid_nsc = get_grid_nsc((2, 2, 2), eigenstate) - if eigenstate.info["k"] == (0,0,0): + if eigenstate.info["k"] == (0, 0, 0): assert grid_nsc == (2, 2, 2) elif eigenstate.info["k"] == (0.5, 0, 0): assert grid_nsc == (1, 2, 2) -def test_create_wf_grid(eigenstate, graphene): +def test_create_wf_grid(eigenstate, graphene): new_graphene = graphene.copy() grid = create_wf_grid(eigenstate, grid_prec=0.2, geometry=new_graphene) @@ -95,7 +96,7 @@ def test_create_wf_grid(eigenstate, graphene): assert grid.geometry is new_graphene # Check that the datatype is correct - if eigenstate.info["k"] == (0,0,0): + if eigenstate.info["k"] == (0, 0, 0): assert grid.grid.dtype == np.float64 else: assert grid.grid.dtype == np.complex128 @@ -103,14 +104,14 @@ def test_create_wf_grid(eigenstate, graphene): # Check that the grid precision is right. assert np.allclose(np.linalg.norm(grid.dcell, axis=1), 0.2, atol=0.01) - provided_grid = sisl.Grid(0.2, geometry=new_graphene, dtype=np.float64) + provided_grid = sisl.Grid(0.2, geometry=new_graphene, dtype=np.float64) grid = create_wf_grid(eigenstate, grid=provided_grid) assert grid is provided_grid -def test_project_wavefunction(eigenstate, graphene): +def test_project_wavefunction(eigenstate, graphene): k = eigenstate.info["k"] grid = project_wavefunction(eigenstate[2], geometry=graphene) @@ -118,10 +119,10 @@ def test_project_wavefunction(eigenstate, graphene): assert isinstance(grid, sisl.Grid) # Check that the datatype is correct - if k == (0,0,0): + if k == (0, 0, 0): assert grid.grid.dtype == np.float64 else: - assert grid.grid.dtype == np.complex128 + assert grid.grid.dtype == np.complex128 # Check that the grid is not empty assert not np.allclose(grid.grid, 0) diff --git a/src/sisl/viz/processors/tests/test_geometry.py b/src/sisl/viz/processors/tests/test_geometry.py index afc448dd92..a792ed7898 100644 --- a/src/sisl/viz/processors/tests/test_geometry.py +++ b/src/sisl/viz/processors/tests/test_geometry.py @@ -26,15 +26,16 @@ def geometry(): return sisl.geom.bcc(2.93, "Au", True) + @pytest.fixture(scope="module") def coords_dataset(geometry): - return xr.Dataset( - {"xyz": (("atom", "axis"), geometry.xyz)}, - coords={"axis": [0,1,2]}, - attrs={"geometry": geometry} + {"xyz": (("atom", "axis"), geometry.xyz)}, + coords={"axis": [0, 1, 2]}, + attrs={"geometry": geometry}, ) + def test_tile_geometry(): geom = sisl.geom.graphene() @@ -42,6 +43,7 @@ def test_tile_geometry(): assert np.allclose(tiled_geometry.cell.T, geom.cell.T * (2, 3, 5)) + def test_find_all_bonds(): geom = sisl.geom.graphene() @@ -57,18 +59,18 @@ def test_find_all_bonds(): assert bonds.bonds.shape == (23, 2) # Now get bonds only for the unit cell - geom.set_nsc([1,1,1]) + geom.set_nsc([1, 1, 1]) bonds = find_all_bonds(geom, 1.5) assert bonds.bonds.shape == (1, 2) - assert np.all(bonds.bonds == (0,1)) + assert np.all(bonds.bonds == (0, 1)) # Run function with just one atom bonds = find_all_bonds(geom.sub(0), 1.5) -def test_get_atom_bonds(): - bonds = np.array([[0,1], [0,2], [1,2]]) +def test_get_atom_bonds(): + bonds = np.array([[0, 1], [0, 2], [1, 2]]) mask = get_atoms_bonds(bonds, [0], ret_mask=True) @@ -80,10 +82,10 @@ def test_get_atom_bonds(): assert isinstance(atom_bonds, np.ndarray) assert atom_bonds.shape == (2, 2) - assert np.all(atom_bonds == [[0,1], [0,2]]) + assert np.all(atom_bonds == [[0, 1], [0, 2]]) -def test_sanitize_atoms(): +def test_sanitize_atoms(): geom = sisl.geom.graphene() sanitized = sanitize_atoms(geom, 3) @@ -91,8 +93,8 @@ def test_sanitize_atoms(): assert len(sanitized) == 1 assert sanitized[0] == 3 -def test_data_sc(coords_dataset): +def test_data_sc(coords_dataset): assert "isc" not in coords_dataset.dims # First, check that not tiling works as expected @@ -107,10 +109,13 @@ def test_data_sc(coords_dataset): assert "isc" in tiled.dims assert len(tiled.isc) == 2 assert np.allclose(tiled.sel(isc=0).xyz, coords_dataset.xyz) - assert np.allclose(tiled.sel(isc=1).xyz, coords_dataset.xyz + coords_dataset.attrs["geometry"].cell[0]) + assert np.allclose( + tiled.sel(isc=1).xyz, + coords_dataset.xyz + coords_dataset.attrs["geometry"].cell[0], + ) -def test_stack_sc_data(coords_dataset): +def test_stack_sc_data(coords_dataset): tiled = tile_data_sc(coords_dataset, nsc=(3, 3, 1)) assert "isc" in tiled.dims @@ -121,6 +126,7 @@ def test_stack_sc_data(coords_dataset): assert "sc_atom" in stacked.dims assert len(stacked.sc_atom) == 9 * len(coords_dataset.atom) + @pytest.mark.parametrize("data_type", [list, dict]) def test_parse_atoms_style_empty(data_type): g = sisl.geom.graphene() @@ -134,6 +140,7 @@ def test_parse_atoms_style_empty(data_type): for data_var in styles.data_vars: assert len(styles[data_var].shape) == 0 + @pytest.mark.parametrize("data_type", [list, dict]) def test_parse_atoms_style_single_values(data_type): g = sisl.geom.graphene() @@ -157,8 +164,8 @@ def test_parse_atoms_style_single_values(data_type): elif data_var == "size": assert styles[data_var].values == 14 -def test_add_xyz_to_dataset(geometry): +def test_add_xyz_to_dataset(geometry): parsed_atoms_style = parse_atoms_style(geometry, {"color": "green", "size": 14}) atoms_dataset = add_xyz_to_dataset(parsed_atoms_style) @@ -169,36 +176,37 @@ def test_add_xyz_to_dataset(geometry): assert atoms_dataset.xyz.shape == (geometry.na, 3) assert np.allclose(atoms_dataset.xyz, geometry.xyz) + @pytest.mark.parametrize("data_type", [list, dict]) def test_sanitize_arrows_empty(data_type): g = sisl.geom.graphene() - arrows = sanitize_arrows(g, data_type(), atoms=None, ndim=3, axes="xyz" ) + arrows = sanitize_arrows(g, data_type(), atoms=None, ndim=3, axes="xyz") assert isinstance(arrows, list) assert len(arrows) == 0 -def test_sanitize_arrows(): - data = np.array([[0,0,0],[1,1,1]]) +def test_sanitize_arrows(): + data = np.array([[0, 0, 0], [1, 1, 1]]) g = sisl.geom.graphene() unparsed = [{"data": data}] - arrows = sanitize_arrows(g, unparsed, atoms=None, ndim=3, axes="xyz" ) + arrows = sanitize_arrows(g, unparsed, atoms=None, ndim=3, axes="xyz") assert isinstance(arrows, list) - assert np.allclose(arrows[0]['data'], data) - - arrows_from_dict = sanitize_arrows(g, unparsed[0], atoms=None, ndim=3, axes="xyz" ) + assert np.allclose(arrows[0]["data"], data) + + arrows_from_dict = sanitize_arrows(g, unparsed[0], atoms=None, ndim=3, axes="xyz") assert isinstance(arrows_from_dict, list) for k, v in arrows[0].items(): if not isinstance(v, np.ndarray): assert arrows[0][k] == arrows_from_dict[0][k] -def test_style_bonds(geometry): +def test_style_bonds(geometry): bonds = find_all_bonds(geometry, 1.5) # Test no styles @@ -231,7 +239,9 @@ def some_property(geometry, bonds): for k in ("color", "width", "opacity"): assert k in styled_bonds.data_vars, f"Missing {k}" assert styled_bonds[k].shape == (len(bonds.bonds),), f"Wrong shape for {k}" - assert np.all(styled_bonds[k].values == np.arange(len(bonds.bonds))), f"Wrong value for {k}" + assert np.all( + styled_bonds[k].values == np.arange(len(bonds.bonds)) + ), f"Wrong value for {k}" # Test scale styles = {"color": some_property, "width": some_property, "opacity": some_property} @@ -242,12 +252,16 @@ def some_property(geometry, bonds): assert k in styled_bonds.data_vars, f"Missing {k}" assert styled_bonds[k].shape == (len(bonds.bonds),), f"Wrong shape for {k}" if k == "width": - assert np.all(styled_bonds[k].values == 2 * np.arange(len(bonds.bonds))), f"Wrong value for {k}" + assert np.all( + styled_bonds[k].values == 2 * np.arange(len(bonds.bonds)) + ), f"Wrong value for {k}" else: - assert np.all(styled_bonds[k].values == np.arange(len(bonds.bonds))), f"Wrong value for {k}" + assert np.all( + styled_bonds[k].values == np.arange(len(bonds.bonds)) + ), f"Wrong value for {k}" -def test_add_xyz_to_bonds_dataset(geometry): +def test_add_xyz_to_bonds_dataset(geometry): bonds = find_all_bonds(geometry, 1.5) xyz_bonds = add_xyz_to_bonds_dataset(bonds) @@ -257,8 +271,8 @@ def test_add_xyz_to_bonds_dataset(geometry): assert xyz_bonds.xyz.shape == (len(bonds.bonds), 2, 3) assert np.allclose(xyz_bonds.xyz[:, 0], geometry.xyz[bonds.bonds[:, 0]]) -def test_sanitize_bonds_selection(geometry): +def test_sanitize_bonds_selection(geometry): bonds = find_all_bonds(geometry, 1.5) # No selection @@ -279,8 +293,8 @@ def test_sanitize_bonds_selection(geometry): assert isinstance(bonds_sel, np.ndarray) assert (bonds.sel(bond_index=bonds_sel) == 0).any("bond_atom").all("bond_index") -def test_bonds_to_lines(geometry): +def test_bonds_to_lines(geometry): bonds = find_all_bonds(geometry, 1.5) xyz_bonds = add_xyz_to_bonds_dataset(bonds) @@ -299,4 +313,3 @@ def test_bonds_to_lines(geometry): assert isinstance(bond_lines, xr.Dataset) assert "point_index" in bond_lines.dims assert len(bond_lines.point_index) == len(xyz_bonds.bond_index) * 11 - diff --git a/src/sisl/viz/processors/tests/test_grid.py b/src/sisl/viz/processors/tests/test_grid.py index eb425a8e43..f128d45495 100644 --- a/src/sisl/viz/processors/tests/test_grid.py +++ b/src/sisl/viz/processors/tests/test_grid.py @@ -29,39 +29,52 @@ def skewed(request) -> bool: return request.param == "skewed" -real_part = np.arange(10*10*10).reshape(10,10,10) -imag_part = np.arange(10*10*10).reshape(10,10,10) + 1 + +real_part = np.arange(10 * 10 * 10).reshape(10, 10, 10) +imag_part = np.arange(10 * 10 * 10).reshape(10, 10, 10) + 1 + @pytest.fixture(scope="module") def origin(): return [1, 2, 3] + @pytest.fixture(scope="module") def grid(origin, skewed) -> Grid: - if skewed: lattice = Lattice([[3, 0, 0], [1, -1, 0], [0, 0, 3]], origin=origin) else: lattice = Lattice([[3, 0, 0], [0, 2, 0], [0, 0, 6]], origin=origin) - + geometry = Geometry([[0, 0, 0]], lattice=lattice) grid = Grid([10, 10, 10], geometry=geometry, dtype=np.complex128) - grid.grid[:] = ( real_part + imag_part * 1j).reshape(10, 10, 10) + grid.grid[:] = (real_part + imag_part * 1j).reshape(10, 10, 10) return grid -def test_get_grid_representation(grid): +def test_get_grid_representation(grid): assert np.allclose(get_grid_representation(grid, "real").grid, real_part) assert np.allclose(get_grid_representation(grid, "imag").grid, imag_part) - assert np.allclose(get_grid_representation(grid, "mod").grid, np.sqrt(real_part**2 + imag_part**2)) - assert np.allclose(get_grid_representation(grid, "phase").grid, np.arctan2(imag_part, real_part)) - assert np.allclose(get_grid_representation(grid, "rad_phase").grid, np.arctan2(imag_part, real_part)) - assert np.allclose(get_grid_representation(grid, "deg_phase").grid, np.arctan2(imag_part, real_part) * 180 / np.pi) + assert np.allclose( + get_grid_representation(grid, "mod").grid, + np.sqrt(real_part**2 + imag_part**2), + ) + assert np.allclose( + get_grid_representation(grid, "phase").grid, np.arctan2(imag_part, real_part) + ) + assert np.allclose( + get_grid_representation(grid, "rad_phase").grid, + np.arctan2(imag_part, real_part), + ) + assert np.allclose( + get_grid_representation(grid, "deg_phase").grid, + np.arctan2(imag_part, real_part) * 180 / np.pi, + ) -def test_tile_grid(grid): +def test_tile_grid(grid): # By default it is not tiled tiled = tile_grid(grid) assert isinstance(tiled, Grid) @@ -73,12 +86,12 @@ def test_tile_grid(grid): tiled = tile_grid(grid, (1, 2, 1)) assert isinstance(tiled, Grid) assert tiled.shape == (grid.shape[0], grid.shape[1] * 2, grid.shape[2]) - assert np.allclose(tiled.grid[:, :grid.shape[1]], grid.grid) - assert np.allclose(tiled.grid[:, grid.shape[1]:], grid.grid) + assert np.allclose(tiled.grid[:, : grid.shape[1]], grid.grid) + assert np.allclose(tiled.grid[:, grid.shape[1] :], grid.grid) assert np.allclose(tiled.origin, grid.origin) -def test_transform_grid_cell(grid, skewed): +def test_transform_grid_cell(grid, skewed): # Convert to a cartesian cell new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10)) @@ -101,10 +114,10 @@ def test_transform_grid_cell(grid, skewed): for i in range(3): n = new_grid.lattice.cell[i] / directions[i] assert np.allclose(n, n[0]) - + + @pytest.mark.parametrize("interp", [1, 2]) def test_orthogonalize_grid(grid, interp, skewed): - ort_grid = orthogonalize_grid(grid, interp=(interp, interp, interp)) assert ort_grid.shape == (10, 10, 10) if interp == 1 else (20, 20, 20) @@ -121,16 +134,18 @@ def test_orthogonalize_grid(grid, interp, skewed): assert not np.allclose(ort_grid.grid, 0) assert np.allclose(ort_grid.origin, grid.origin) == (not skewed and interp == 1) -def test_should_transform_grid_cell_plotting(grid, skewed): +def test_should_transform_grid_cell_plotting(grid, skewed): assert should_transform_grid_cell_plotting(grid, axes=["x", "y"]) == skewed assert should_transform_grid_cell_plotting(grid, axes=["z"]) == False + @pytest.mark.parametrize("interp", [1, 2]) def test_orthogonalize_grid_if_needed(grid, skewed, interp): - # Orthogonalize the skewed cell, since it is xy skewed. - ort_grid = orthogonalize_grid_if_needed(grid, axes=["x", "y"], interp=(interp, interp, interp)) + ort_grid = orthogonalize_grid_if_needed( + grid, axes=["x", "y"], interp=(interp, interp, interp) + ) assert ort_grid.shape == (10, 10, 10) if interp == 1 else (20, 20, 20) assert ort_grid.lattice.is_cartesian() @@ -147,18 +162,20 @@ def test_orthogonalize_grid_if_needed(grid, skewed, interp): assert np.allclose(ort_grid.origin, grid.origin) == (not skewed) # Do not orthogonalize the skewed cell, since it is not z skewed. - ort_grid = orthogonalize_grid_if_needed(grid, axes=["z"], interp=(interp, interp, interp)) + ort_grid = orthogonalize_grid_if_needed( + grid, axes=["z"], interp=(interp, interp, interp) + ) assert ort_grid.shape == (10, 10, 10) if interp == 1 else (20, 20, 20) - + if skewed: assert not ort_grid.lattice.is_cartesian() assert np.allclose(ort_grid.lattice.cell, grid.lattice.cell) assert np.allclose(ort_grid.grid, grid.grid) -def test_apply_transforms(grid): +def test_apply_transforms(grid): # Apply a function transf = apply_transforms(grid, transforms=[np.sqrt]) assert np.allclose(transf.grid, np.sqrt(grid.grid)) @@ -174,13 +191,10 @@ def test_apply_transforms(grid): assert np.allclose(transf.grid, np.sqrt(np.angle(grid.grid))) assert np.allclose(transf.origin, grid.origin) + @pytest.mark.parametrize("reduce_method", ["sum", "mean"]) def test_reduce_grid(grid, reduce_method): - - reduce_func = { - "sum": np.sum, - "mean": np.mean - }[reduce_method] + reduce_func = {"sum": np.sum, "mean": np.mean}[reduce_method] reduced = reduce_grid(grid, reduce_method, keep_axes=[0, 1]) @@ -188,12 +202,12 @@ def test_reduce_grid(grid, reduce_method): assert np.allclose(reduced.grid[:, :, 0], reduce_func(grid.grid, axis=2)) assert np.allclose(reduced.origin, grid.origin) + @pytest.mark.parametrize("direction", ["x", "y", "z"]) def test_sub_grid(grid, skewed, direction): - coord_ax = "xyz".index(direction) kwargs = {f"{direction}_range": (0.5, 1.5)} - + expected_origin = grid.origin.copy() if skewed and direction != "z": @@ -205,140 +219,148 @@ def test_sub_grid(grid, skewed, direction): # Check that the lattice has been reduced to contain the requested range, # taking into account that the bounds of the range might not be exactly # on the grid points. - assert 1 + sub.dcell[:, coord_ax].sum()*2 >= sub.lattice.cell[:, coord_ax].sum() >= 1 - sub.dcell[:, coord_ax].sum()*2 + assert ( + 1 + sub.dcell[:, coord_ax].sum() * 2 + >= sub.lattice.cell[:, coord_ax].sum() + >= 1 - sub.dcell[:, coord_ax].sum() * 2 + ) expected_origin[coord_ax] += 0.5 - + assert np.allclose(sub.origin, expected_origin) -def test_interpolate_grid(grid): +def test_interpolate_grid(grid): interp = interpolate_grid(grid, (20, 20, 20)) # Check that the shape has been augmented assert np.all(interp.shape == np.array((20, 20, 20)) * grid.shape) - + # The integral over the grid should be the same (or very similar) assert (grid.grid.sum() * 20**3 - interp.grid.sum()) < 1e-3 -def test_grid_geometry(grid): +def test_grid_geometry(grid): assert grid_geometry(grid) is grid.geometry geom_copy = grid.geometry.copy() assert grid_geometry(grid, geom_copy) is geom_copy -def test_get_grid_axes(grid, skewed): - assert get_grid_axes(grid, ['x', 'y', 'z']) == [0, 1, 2] +def test_get_grid_axes(grid, skewed): + assert get_grid_axes(grid, ["x", "y", "z"]) == [0, 1, 2] # This function doesn't care about what the axes are in 3D - assert get_grid_axes(grid, ['y', '-x', 'z']) == [0, 1, 2] + assert get_grid_axes(grid, ["y", "-x", "z"]) == [0, 1, 2] if skewed: with pytest.raises(ValueError): - get_grid_axes(grid, ['x', 'y']) + get_grid_axes(grid, ["x", "y"]) else: - assert get_grid_axes(grid, ['x', 'y']) == [0, 1] - assert get_grid_axes(grid, ['y', 'x']) == [1, 0] + assert get_grid_axes(grid, ["x", "y"]) == [0, 1] + assert get_grid_axes(grid, ["y", "x"]) == [1, 0] -def test_get_ax_vals(grid, skewed, origin): +def test_get_ax_vals(grid, skewed, origin): r = get_ax_vals(grid, "x", nsc=(1, 1, 1)) assert isinstance(r, np.ndarray) - assert r.shape == (grid.shape[0], ) + assert r.shape == (grid.shape[0],) if not skewed: assert r[0] == origin[0] - assert abs(r[-1] - (origin[0] + grid.lattice.cell[0, 0] - grid.dcell[0, 0])) < 1e-3 + assert ( + abs(r[-1] - (origin[0] + grid.lattice.cell[0, 0] - grid.dcell[0, 0])) < 1e-3 + ) r = get_ax_vals(grid, "a", nsc=(2, 1, 1)) assert isinstance(r, np.ndarray) - assert r.shape == (grid.shape[0], ) + assert r.shape == (grid.shape[0],) assert r[0] == 0 assert abs(r[-1] - 2) < 1e-3 -def test_get_offset(grid, origin): +def test_get_offset(grid, origin): assert get_offset(grid, "x") == origin[0] assert get_offset(grid, "b") == 0 assert get_offset(grid, 2) == 0 + def test_grid_to_dataarray(grid, skewed): # Test 1D av_grid = grid.average(0).average(1) - arr = grid_to_dataarray(av_grid, ['z'], [2], nsc=(1,1,1)) + arr = grid_to_dataarray(av_grid, ["z"], [2], nsc=(1, 1, 1)) assert isinstance(arr, xr.DataArray) assert len(arr.coords) == 1 assert "x" in arr.coords - assert arr.x.shape == (grid.shape[2], ) + assert arr.x.shape == (grid.shape[2],) assert np.allclose(arr.values, av_grid.grid[0, 0, :]) if skewed: return - + # Test 2D av_grid = grid.average(0) - arr = grid_to_dataarray(av_grid, ['y', 'z'], [1, 2], nsc=(1,1,1)) + arr = grid_to_dataarray(av_grid, ["y", "z"], [1, 2], nsc=(1, 1, 1)) assert isinstance(arr, xr.DataArray) assert len(arr.coords) == 2 assert "x" in arr.coords - assert arr.x.shape == (grid.shape[1], ) + assert arr.x.shape == (grid.shape[1],) assert "y" in arr.coords - assert arr.y.shape == (grid.shape[2], ) - + assert arr.y.shape == (grid.shape[2],) + assert np.allclose(arr.values, av_grid.grid[0, :, :]) # Test 2D with unordered axes av_grid = grid.average(0) - arr = grid_to_dataarray(av_grid, ['z', 'y'], [2, 1], nsc=(1,1,1)) + arr = grid_to_dataarray(av_grid, ["z", "y"], [2, 1], nsc=(1, 1, 1)) assert isinstance(arr, xr.DataArray) assert len(arr.coords) == 2 assert "x" in arr.coords - assert arr.x.shape == (grid.shape[2], ) + assert arr.x.shape == (grid.shape[2],) assert "y" in arr.coords - assert arr.y.shape == (grid.shape[1], ) - + assert arr.y.shape == (grid.shape[1],) + assert np.allclose(arr.values, av_grid.grid[0, :, :].T) # Test 3D av_grid = grid - arr = grid_to_dataarray(av_grid, ['x', 'y', 'z'], [0, 1, 2], nsc=(1,1,1)) + arr = grid_to_dataarray(av_grid, ["x", "y", "z"], [0, 1, 2], nsc=(1, 1, 1)) assert isinstance(arr, xr.DataArray) assert len(arr.coords) == 3 assert "x" in arr.coords - assert arr.x.shape == (grid.shape[0], ) + assert arr.x.shape == (grid.shape[0],) assert "y" in arr.coords - assert arr.y.shape == (grid.shape[1], ) + assert arr.y.shape == (grid.shape[1],) assert "z" in arr.coords - assert arr.z.shape == (grid.shape[2], ) - + assert arr.z.shape == (grid.shape[2],) + assert np.allclose(arr.values, av_grid.grid) + def test_get_isos(grid, skewed): pytest.importorskip("skimage") if skewed: return - + # Test isocontours (2D) - arr = grid_to_dataarray(grid.average(2), ['x', 'y'], [0, 1, 2], nsc=(1,1,1)) + arr = grid_to_dataarray(grid.average(2), ["x", "y"], [0, 1, 2], nsc=(1, 1, 1)) assert get_isos(arr, []) == [] - contours = get_isos(arr, [{'frac': 0.5}]) + contours = get_isos(arr, [{"frac": 0.5}]) assert isinstance(contours, list) assert len(contours) == 1 @@ -350,7 +372,7 @@ def test_get_isos(grid, skewed): assert "z" not in contours[0] # Test isosurfaces (3D) - arr = grid_to_dataarray(grid, ['x', 'y', 'z'], [0, 1, 2], nsc=(1,1,1)) + arr = grid_to_dataarray(grid, ["x", "y", "z"], [0, 1, 2], nsc=(1, 1, 1)) surfs = get_isos(arr, []) @@ -374,7 +396,7 @@ def test_get_isos(grid, skewed): assert surfs[0]["faces"].dtype == np.int32 assert surfs[0]["faces"].shape[1] == 3 - surfs = get_isos(arr, [{'val': 3, "color": "red", "opacity": 0.5, "name": "test"}]) + surfs = get_isos(arr, [{"val": 3, "color": "red", "opacity": 0.5, "name": "test"}]) assert isinstance(surfs, list) assert len(surfs) == 1 @@ -393,4 +415,3 @@ def test_get_isos(grid, skewed): assert isinstance(surfs[0]["faces"], np.ndarray) assert surfs[0]["faces"].dtype == np.int32 assert surfs[0]["faces"].shape[1] == 3 - diff --git a/src/sisl/viz/processors/tests/test_groupreduce.py b/src/sisl/viz/processors/tests/test_groupreduce.py index d894bf286f..a197237513 100644 --- a/src/sisl/viz/processors/tests/test_groupreduce.py +++ b/src/sisl/viz/processors/tests/test_groupreduce.py @@ -9,85 +9,119 @@ @pytest.fixture(scope="module") def dataarray(): - return xr.DataArray([ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - [10, 11, 12], - ], coords=[("x", [0,1,2,3]), ("y", [0,1,2])], name="vals" + return xr.DataArray( + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + ], + coords=[("x", [0, 1, 2, 3]), ("y", [0, 1, 2])], + name="vals", ) + @pytest.fixture(scope="module") def dataset(): - arr = xr.DataArray([ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - [10, 11, 12], - ], coords=[("x", [0,1,2,3]), ("y", [0,1,2])], name="vals" + arr = xr.DataArray( + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + ], + coords=[("x", [0, 1, 2, 3]), ("y", [0, 1, 2])], + name="vals", ) arr2 = arr * 2 return xr.Dataset({"vals": arr, "double": arr2}) -def test_dataarray(dataarray): - new = group_reduce(dataarray, [{"selector": [0,1]}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection") +def test_dataarray(dataarray): + new = group_reduce( + dataarray, + [{"selector": [0, 1]}, {"selector": [2, 3]}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims assert "selection" in new.dims assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] assert new.sel(selection=0).sum() == 1 + 2 + 3 + 4 + 5 + 6 assert new.sel(selection=1).sum() == 7 + 8 + 9 + 10 + 11 + 12 -def test_dataarray_multidim(dataarray): - new = group_reduce(dataarray, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], reduce_dim=("x", "y"), reduce_func=np.sum, groups_dim="selection") +def test_dataarray_multidim(dataarray): + new = group_reduce( + dataarray, + [{"selector": ([0, 1], [0, 1])}, {"selector": ([2, 3], [0, 1])}], + reduce_dim=("x", "y"), + reduce_func=np.sum, + groups_dim="selection", + ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims assert "y" not in new.dims assert "selection" in new.dims assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] assert new.sel(selection=0).sum() == 1 + 2 + 4 + 5 assert new.sel(selection=1).sum() == 7 + 8 + 10 + 11 -def test_dataarray_multidim_multireduce(dataarray): - new = group_reduce(dataarray, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], reduce_dim=("x", "y"), - reduce_func=(np.sum, np.mean), groups_dim="selection") +def test_dataarray_multidim_multireduce(dataarray): + new = group_reduce( + dataarray, + [{"selector": ([0, 1], [0, 1])}, {"selector": ([2, 3], [0, 1])}], + reduce_dim=("x", "y"), + reduce_func=(np.sum, np.mean), + groups_dim="selection", + ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims assert "y" not in new.dims assert "selection" in new.dims assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] assert new.sel(selection=0).sum() == (1 + 4) / 2 + (2 + 5) / 2 assert new.sel(selection=1).sum() == (7 + 10) / 2 + (8 + 11) / 2 + def test_dataarray_sangroup(dataarray): # We use sanitize group to simply set all selectors to [0,1] - new = group_reduce(dataarray, [{"selector": [0,1]}, {"selector": [2, 3]}], - reduce_dim="x", reduce_func=np.sum, groups_dim="selection", - sanitize_group=lambda group: {**group, "selector": [0,1]} + new = group_reduce( + dataarray, + [{"selector": [0, 1]}, {"selector": [2, 3]}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + sanitize_group=lambda group: {**group, "selector": [0, 1]}, ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims assert "selection" in new.dims assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] assert new.sel(selection=0).sum() == 1 + 2 + 3 + 4 + 5 + 6 assert new.sel(selection=1).sum() == 1 + 2 + 3 + 4 + 5 + 6 -def test_dataarray_names(dataarray): - new = group_reduce(dataarray, [{"selector": [0,1], "name": "first"}, {"selector": [2, 3], "name": "second"}], - reduce_dim="x", reduce_func=np.sum, groups_dim="selection") +def test_dataarray_names(dataarray): + new = group_reduce( + dataarray, + [{"selector": [0, 1], "name": "first"}, {"selector": [2, 3], "name": "second"}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims @@ -97,14 +131,18 @@ def test_dataarray_names(dataarray): assert new.sel(selection="first").sum() == 1 + 2 + 3 + 4 + 5 + 6 assert new.sel(selection="second").sum() == 7 + 8 + 9 + 10 + 11 + 12 -def test_dataarray_groupvars(dataarray): - new = group_reduce(dataarray, +def test_dataarray_groupvars(dataarray): + new = group_reduce( + dataarray, [ - {"selector": [0,1], "name": "first", "color": "red", "size": 3}, - {"selector": [2, 3], "name": "second", "color": "blue", "size": 4} - ], - reduce_dim="x", reduce_func=np.sum, groups_dim="selection", group_vars=["color", "size"] + {"selector": [0, 1], "name": "first", "color": "red", "size": 3}, + {"selector": [2, 3], "name": "second", "color": "blue", "size": 4}, + ], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + group_vars=["color", "size"], ) assert isinstance(new, xr.Dataset) @@ -125,10 +163,18 @@ def test_dataarray_groupvars(dataarray): assert list(k_data.coords["selection"]) == ["first", "second"] assert list(k_data) == vals + @pytest.mark.parametrize("drop", [True, False]) def test_dataarray_empty_selector(dataarray, drop): - - new = group_reduce(dataarray, [{"selector": []}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection", drop_empty=drop, fill_empty=0.) + new = group_reduce( + dataarray, + [{"selector": []}, {"selector": [2, 3]}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + drop_empty=drop, + fill_empty=0.0, + ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims @@ -138,37 +184,48 @@ def test_dataarray_empty_selector(dataarray, drop): assert list(new.coords["selection"]) == [1] else: assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] if not drop: - assert new.sel(selection=0).sum() == 0. + assert new.sel(selection=0).sum() == 0.0 assert new.sel(selection=1).sum() == 7 + 8 + 9 + 10 + 11 + 12 + def test_dataarray_empty_selector_0d(): """When reducing an array along its only dimension, you get a 0d array. - + This was creating an error when filling empty selections. This test ensures that it doesn' happen again """ new = group_reduce( - xr.DataArray([1,2,3], coords=[("x", [0,1,2])]), - [{"selector": []}, {"selector": [1, 2]}], reduce_dim="x", reduce_func=np.sum, - groups_dim="selection", drop_empty=False, fill_empty=0. + xr.DataArray([1, 2, 3], coords=[("x", [0, 1, 2])]), + [{"selector": []}, {"selector": [1, 2]}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + drop_empty=False, + fill_empty=0.0, ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims assert "selection" in new.dims assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] - assert new.sel(selection=0).sum() == 0. + assert new.sel(selection=0).sum() == 0.0 assert new.sel(selection=1).sum() == 2 + 3 -def test_dataset(dataset): - new = group_reduce(dataset, [{"selector": [0,1]}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection") +def test_dataset(dataset): + new = group_reduce( + dataset, + [{"selector": [0, 1]}, {"selector": [2, 3]}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + ) assert isinstance(new, xr.Dataset) assert "x" not in new.dims @@ -181,10 +238,15 @@ def test_dataset(dataset): assert new.sel(selection=0).sum() == (1 + 2 + 3 + 4 + 5 + 6) * 3 assert new.sel(selection=1).sum() == (7 + 8 + 9 + 10 + 11 + 12) * 3 -def test_dataset_multidim(dataset): - new = group_reduce(dataset, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], - reduce_dim=("x", "y"), reduce_func=np.sum, groups_dim="selection") +def test_dataset_multidim(dataset): + new = group_reduce( + dataset, + [{"selector": ([0, 1], [0, 1])}, {"selector": ([2, 3], [0, 1])}], + reduce_dim=("x", "y"), + reduce_func=np.sum, + groups_dim="selection", + ) assert isinstance(new, xr.Dataset) assert "x" not in new.dims @@ -198,10 +260,15 @@ def test_dataset_multidim(dataset): assert new.sel(selection=0).sum() == (1 + 2 + 4 + 5) * 3 assert new.sel(selection=1).sum() == (7 + 8 + 10 + 11) * 3 -def test_dataset_multidim_multireduce(dataset): - new = group_reduce(dataset, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], - reduce_dim=("x", "y"), reduce_func=(np.sum, np.mean), groups_dim="selection") +def test_dataset_multidim_multireduce(dataset): + new = group_reduce( + dataset, + [{"selector": ([0, 1], [0, 1])}, {"selector": ([2, 3], [0, 1])}], + reduce_dim=("x", "y"), + reduce_func=(np.sum, np.mean), + groups_dim="selection", + ) assert isinstance(new, xr.Dataset) assert "x" not in new.dims @@ -212,26 +279,37 @@ def test_dataset_multidim_multireduce(dataset): assert "vals" in new assert "double" in new - assert new.sel(selection=0).sum() == ((1 + 4) / 2 + (2 + 5) / 2 ) * 3 + assert new.sel(selection=0).sum() == ((1 + 4) / 2 + (2 + 5) / 2) * 3 assert new.sel(selection=1).sum() == ((7 + 10) / 2 + (8 + 11) / 2) * 3 -def test_dataarray_multidim_multireduce(dataarray): - new = group_reduce(dataarray, [{"selector": ([0,1], [0,1])}, {"selector": ([2, 3], [0,1])}], reduce_dim=("x", "y"), - reduce_func=(np.sum, np.mean), groups_dim="selection") +def test_dataarray_multidim_multireduce(dataarray): + new = group_reduce( + dataarray, + [{"selector": ([0, 1], [0, 1])}, {"selector": ([2, 3], [0, 1])}], + reduce_dim=("x", "y"), + reduce_func=(np.sum, np.mean), + groups_dim="selection", + ) assert isinstance(new, xr.DataArray) assert "x" not in new.dims assert "y" not in new.dims assert "selection" in new.dims assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] - assert new.sel(selection=0).sum() == (1 + 4) / 2 + (2 + 5) / 2 + assert list(new.coords["selection"]) == [0, 1] + assert new.sel(selection=0).sum() == (1 + 4) / 2 + (2 + 5) / 2 assert new.sel(selection=1).sum() == (7 + 10) / 2 + (8 + 11) / 2 -def test_dataset_names(dataset): - new = group_reduce(dataset, [{"selector": [0,1], "name": "first"}, {"selector": [2, 3], "name": "second"}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection") +def test_dataset_names(dataset): + new = group_reduce( + dataset, + [{"selector": [0, 1], "name": "first"}, {"selector": [2, 3], "name": "second"}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + ) assert isinstance(new, xr.Dataset) assert "x" not in new.dims @@ -244,14 +322,18 @@ def test_dataset_names(dataset): assert new.sel(selection="first").sum() == (1 + 2 + 3 + 4 + 5 + 6) * 3 assert new.sel(selection="second").sum() == (7 + 8 + 9 + 10 + 11 + 12) * 3 -def test_dataset_groupvars(dataset): - new = group_reduce(dataset, +def test_dataset_groupvars(dataset): + new = group_reduce( + dataset, [ - {"selector": [0,1], "name": "first", "color": "red", "size": 3}, - {"selector": [2, 3], "name": "second", "color": "blue", "size": 4} - ], - reduce_dim="x", reduce_func=np.sum, groups_dim="selection", group_vars=["color", "size"] + {"selector": [0, 1], "name": "first", "color": "red", "size": 3}, + {"selector": [2, 3], "name": "second", "color": "blue", "size": 4}, + ], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + group_vars=["color", "size"], ) assert isinstance(new, xr.Dataset) @@ -273,10 +355,18 @@ def test_dataset_groupvars(dataset): assert list(k_data.coords["selection"]) == ["first", "second"] assert list(k_data) == vals + @pytest.mark.parametrize("drop", [True, False]) def test_dataset_empty_selector(dataset, drop): - - new = group_reduce(dataset, [{"selector": []}, {"selector": [2, 3]}], reduce_dim="x", reduce_func=np.sum, groups_dim="selection", drop_empty=drop, fill_empty=0.) + new = group_reduce( + dataset, + [{"selector": []}, {"selector": [2, 3]}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + drop_empty=drop, + fill_empty=0.0, + ) assert isinstance(new, xr.Dataset) assert "x" not in new.dims @@ -286,34 +376,39 @@ def test_dataset_empty_selector(dataset, drop): assert list(new.coords["selection"]) == [1] else: assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] assert "vals" in new assert "double" in new if not drop: - assert new.sel(selection=0).sum() == 0. + assert new.sel(selection=0).sum() == 0.0 assert new.sel(selection=1).sum() == (7 + 8 + 9 + 10 + 11 + 12) * 3 + def test_dataaset_empty_selector_0d(): """When reducing an array along its only dimension, you get a 0d array. - + This was creating an error when filling empty selections. This test ensures that it doesn' happen again """ - dataset = xr.Dataset({"vals": (["x"], [1,2,3])}) + dataset = xr.Dataset({"vals": (["x"], [1, 2, 3])}) new = group_reduce( dataset, - [{"selector": []}, {"selector": [1, 2]}], reduce_dim="x", reduce_func=np.sum, - groups_dim="selection", drop_empty=False, fill_empty=0. + [{"selector": []}, {"selector": [1, 2]}], + reduce_dim="x", + reduce_func=np.sum, + groups_dim="selection", + drop_empty=False, + fill_empty=0.0, ) assert isinstance(new, xr.Dataset) assert "x" not in new.dims assert "selection" in new.dims assert len(new.coords["selection"]) == 2 - assert list(new.coords["selection"]) == [0,1] + assert list(new.coords["selection"]) == [0, 1] assert "vals" in new - assert new.sel(selection=0).sum() == 0. + assert new.sel(selection=0).sum() == 0.0 assert new.sel(selection=1).sum() == 2 + 3 diff --git a/src/sisl/viz/processors/tests/test_logic.py b/src/sisl/viz/processors/tests/test_logic.py index 23239dcc9f..b09117b0dd 100644 --- a/src/sisl/viz/processors/tests/test_logic.py +++ b/src/sisl/viz/processors/tests/test_logic.py @@ -6,15 +6,14 @@ def test_swap(): - assert swap(1, (1, 2)) == 2 assert swap(2, (1, 2)) == 1 with pytest.raises(ValueError): swap(3, (1, 2)) -def test_matches(): +def test_matches(): assert matches(1, 1) == True assert matches(1, 2) == False @@ -27,8 +26,7 @@ def test_matches(): assert matches(1, 1, ret_false="b") == True assert matches(1, 2, ret_false="b") == "b" -def test_switch(): +def test_switch(): assert switch(True, "a", "b") == "a" assert switch(False, "a", "b") == "b" - diff --git a/src/sisl/viz/processors/tests/test_math.py b/src/sisl/viz/processors/tests/test_math.py index 7bf5acfb63..c71c29e1ad 100644 --- a/src/sisl/viz/processors/tests/test_math.py +++ b/src/sisl/viz/processors/tests/test_math.py @@ -7,10 +7,8 @@ def test_normalize(): - data = [0, 1, 2] assert np.allclose(normalize(data), [0, 0.5, 1]) assert np.allclose(normalize(data, vmin=-1, vmax=1), [-1, 0, 1]) - diff --git a/src/sisl/viz/processors/tests/test_orbital.py b/src/sisl/viz/processors/tests/test_orbital.py index 94b5b766bc..e4a3bb05ad 100644 --- a/src/sisl/viz/processors/tests/test_orbital.py +++ b/src/sisl/viz/processors/tests/test_orbital.py @@ -17,11 +17,15 @@ @pytest.fixture(scope="module") def geometry(): - orbs = [ - AtomicOrbital("2sZ1"), AtomicOrbital("2sZ2"), - AtomicOrbital("2pxZ1"), AtomicOrbital("2pyZ1"), AtomicOrbital("2pzZ1"), - AtomicOrbital("2pxZ2"), AtomicOrbital("2pyZ2"), AtomicOrbital("2pzZ2"), + AtomicOrbital("2sZ1"), + AtomicOrbital("2sZ2"), + AtomicOrbital("2pxZ1"), + AtomicOrbital("2pyZ1"), + AtomicOrbital("2pzZ1"), + AtomicOrbital("2pxZ2"), + AtomicOrbital("2pyZ2"), + AtomicOrbital("2pzZ2"), ] atoms = [ @@ -30,26 +34,30 @@ def geometry(): ] return sisl.geom.graphene(atoms=atoms) -@pytest.fixture(scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"]) + +@pytest.fixture( + scope="module", params=["unpolarized", "polarized", "noncolinear", "spinorbit"] +) def spin(request): return sisl.Spin(request.param) + @pytest.fixture(scope="module") def orb_manager(geometry, spin): return OrbitalQueriesManager(geometry, spin=spin) -def test_get_orbitals(orb_manager, geometry: Geometry): +def test_get_orbitals(orb_manager, geometry: Geometry): orbs = orb_manager.get_orbitals({"atoms": [0]}) assert len(orbs) == geometry.atoms.atom[0].no assert np.all(orbs == np.arange(geometry.atoms.atom[0].no)) orbs = orb_manager.get_orbitals({"orbitals": [0, 1]}) assert len(orbs) == 2 - assert np.all(orbs == np.array([0,1])) + assert np.all(orbs == np.array([0, 1])) -def test_get_atoms(orb_manager, geometry: Geometry): +def test_get_atoms(orb_manager, geometry: Geometry): ats = orb_manager.get_atoms({"atoms": [0]}) assert len(ats) == 1 assert ats[0] == 0 @@ -65,7 +73,6 @@ def test_get_atoms(orb_manager, geometry: Geometry): def test_split(orb_manager, geometry: Geometry): - # Check that it can split over species queries = orb_manager.generate_queries(split="species") @@ -87,7 +94,6 @@ def test_split(orb_manager, geometry: Geometry): assert len(queries) == geometry.na for i_atom in range(geometry.na): - query = queries[i_atom] assert isinstance(query, dict), f"Query is not a dict: {query}" @@ -103,7 +109,6 @@ def test_split(orb_manager, geometry: Geometry): assert len(queries) != 0 for query in queries: - assert isinstance(query, dict), f"Query is not a dict: {query}" assert "l" in query, f"Query does not have l: {query}" @@ -111,6 +116,7 @@ def test_split(orb_manager, geometry: Geometry): assert len(query["l"]) == 1 assert isinstance(query["l"][0], int) + def test_double_split(orb_manager): # Check that it can split over two things at the same time queries = orb_manager.generate_queries(split="l+m") @@ -132,51 +138,55 @@ def test_double_split(orb_manager): assert abs(query["m"][0]) <= query["l"][0] -def test_split_only(orb_manager, geometry): - queries = orb_manager.generate_queries(split="species", only=[geometry.atoms.atom[0].tag]) +def test_split_only(orb_manager, geometry): + queries = orb_manager.generate_queries( + split="species", only=[geometry.atoms.atom[0].tag] + ) assert len(queries) == 1 - assert queries[0]['species'] == [geometry.atoms.atom[0].tag] + assert queries[0]["species"] == [geometry.atoms.atom[0].tag] -def test_split_exclude(orb_manager, geometry): - queries = orb_manager.generate_queries(split="species", exclude=[geometry.atoms.atom[0].tag]) +def test_split_exclude(orb_manager, geometry): + queries = orb_manager.generate_queries( + split="species", exclude=[geometry.atoms.atom[0].tag] + ) assert len(queries) == geometry.atoms.nspecie - 1 - assert geometry.atoms.atom[0].tag not in [query['species'][0] for query in queries] + assert geometry.atoms.atom[0].tag not in [query["species"][0] for query in queries] -def test_constrained_split(orb_manager, geometry): +def test_constrained_split(orb_manager, geometry): queries = orb_manager.generate_queries(split="species", atoms=[0]) assert len(queries) == 1 - assert queries[0]['species'] == [geometry.atoms.atom[0].tag] + assert queries[0]["species"] == [geometry.atoms.atom[0].tag] -def test_split_name(orb_manager, geometry): +def test_split_name(orb_manager, geometry): queries = orb_manager.generate_queries(split="species", name="Tag: $species") assert len(queries) == geometry.atoms.nspecie for query in queries: assert "name" in query, f"Query does not have name: {query}" - assert query['name'] == f"Tag: {query['species'][0]}" + assert query["name"] == f"Tag: {query['species'][0]}" -def test_sanitize_query(orb_manager, geometry): +def test_sanitize_query(orb_manager, geometry): san_query = orb_manager.sanitize_query({"atoms": [0]}) atom_orbitals = geometry.atoms.atom[0].orbitals - assert len(san_query['orbitals']) == len(atom_orbitals) - assert np.all(san_query['orbitals'] == np.arange(len(atom_orbitals))) + assert len(san_query["orbitals"]) == len(atom_orbitals) + assert np.all(san_query["orbitals"] == np.arange(len(atom_orbitals))) -def test_reduce_orbital_data(geometry, spin): +def test_reduce_orbital_data(geometry, spin): data = PDOSData.toy_example(geometry=geometry, spin=spin)._data - reduced = reduce_orbital_data(data, [{"name": "all"}] ) + reduced = reduce_orbital_data(data, [{"name": "all"}]) assert isinstance(reduced, xr.DataArray) @@ -186,7 +196,7 @@ def test_reduce_orbital_data(geometry, spin): else: assert dim in reduced.dims assert len(data[dim]) == len(reduced[dim]) - + assert "group" in reduced.dims assert len(reduced.group) == 1 assert reduced.group[0] == "all" @@ -196,20 +206,20 @@ def test_reduce_orbital_data(geometry, spin): data_no_geometry.attrs.pop("geometry") with pytest.raises(SislError): - reduced = reduce_orbital_data(data_no_geometry, [{"name": "all"}] ) + reduced = reduce_orbital_data(data_no_geometry, [{"name": "all"}]) -def test_reduce_orbital_data_spin(geometry, spin): +def test_reduce_orbital_data_spin(geometry, spin): data = PDOSData.toy_example(geometry=geometry, spin=spin)._data if spin.is_polarized: - sel_total = reduce_orbital_data(data, [{"name": "all", "spin": "total"}] ) + sel_total = reduce_orbital_data(data, [{"name": "all", "spin": "total"}]) red_total = reduce_orbital_data(data, [{"name": "all"}], spin_reduce=np.sum) assert np.allclose(sel_total.values, red_total.values) -def test_atom_data_from_orbital_data(geometry: Geometry, spin): +def test_atom_data_from_orbital_data(geometry: Geometry, spin): data = PDOSData.toy_example(geometry=geometry, spin=spin)._data atom_data = atom_data_from_orbital_data(data, geometry) @@ -222,7 +232,7 @@ def test_atom_data_from_orbital_data(geometry: Geometry, spin): else: assert dim in atom_data.dims assert len(data[dim]) == len(atom_data[dim]) - + assert "atom" in atom_data.dims assert len(atom_data.atom) == geometry.na assert np.all(atom_data.atom == np.arange(geometry.na)) diff --git a/src/sisl/viz/processors/tests/test_sci_groupreduce.py b/src/sisl/viz/processors/tests/test_sci_groupreduce.py index 064c1eb5bb..2c672050e7 100644 --- a/src/sisl/viz/processors/tests/test_sci_groupreduce.py +++ b/src/sisl/viz/processors/tests/test_sci_groupreduce.py @@ -11,44 +11,57 @@ @pytest.fixture(scope="module") def atom_x(): - geom = sisl.geom.graphene().tile(20,1).tile(20,0) + geom = sisl.geom.graphene().tile(20, 1).tile(20, 0) + + return xr.DataArray( + geom.xyz[:, 0], coords=[("atom", range(geom.na))], attrs={"geometry": geom} + ) - return xr.DataArray(geom.xyz[:, 0], coords=[("atom", range(geom.na))], attrs={"geometry": geom}) @pytest.fixture(scope="module") def atom_xyz(): - geom = sisl.geom.graphene().tile(20,1).tile(20,0) + geom = sisl.geom.graphene().tile(20, 1).tile(20, 0) + + return xr.Dataset( + { + "x": xr.DataArray(geom.xyz[:, 0], coords=[("atom", range(geom.na))]), + "y": xr.DataArray(geom.xyz[:, 1], coords=[("atom", range(geom.na))]), + "z": xr.DataArray(geom.xyz[:, 2], coords=[("atom", range(geom.na))]), + }, + attrs={"geometry": geom}, + ) - return xr.Dataset({ - "x": xr.DataArray(geom.xyz[:, 0], coords=[("atom", range(geom.na))]), - "y": xr.DataArray(geom.xyz[:, 1], coords=[("atom", range(geom.na))]), - "z": xr.DataArray(geom.xyz[:, 2], coords=[("atom", range(geom.na))]), - }, attrs={"geometry": geom}) def test_reduce_atom_dataarray(atom_x): - grouped = reduce_atom_data( atom_x, - [{"atoms": [0,1], "name": "first"}, {"atoms": [5, 6], "name": "second"}], - reduce_func=np.sum, groups_dim="group" + [{"atoms": [0, 1], "name": "first"}, {"atoms": [5, 6], "name": "second"}], + reduce_func=np.sum, + groups_dim="group", ) assert isinstance(grouped, xr.DataArray) assert float(grouped.sel(group="first")) == np.sum(atom_x.values[0:2].sum()) assert float(grouped.sel(group="second")) == np.sum(atom_x.values[5:7].sum()) + def test_reduce_atom_dataarray_cat(atom_x): """We test that the atoms field is correctly sanitized using the geometry attached to the xarray object.""" grouped = reduce_atom_data( atom_x, - [{"atoms": {"x": (0, 10)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], - reduce_func=np.max, groups_dim="group" + [ + {"atoms": {"x": (0, 10)}, "name": "first"}, + {"atoms": {"x": (10, None)}, "name": "second"}, + ], + reduce_func=np.max, + groups_dim="group", ) assert isinstance(grouped, xr.DataArray) assert float(grouped.sel(group="first")) <= 10 assert float(grouped.sel(group="second")) == atom_x.max() + def test_reduce_atom_cat_nogeom(atom_x): """We test that the atoms field is correctly sanitized using the geometry attached to the xarray object.""" atom_x = atom_x.copy() @@ -61,27 +74,41 @@ def test_reduce_atom_cat_nogeom(atom_x): with pytest.raises(Exception): grouped = reduce_atom_data( atom_x, - [{"atoms": {"x": (0, 10)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], - reduce_func=np.max, groups_dim="group" + [ + {"atoms": {"x": (0, 10)}, "name": "first"}, + {"atoms": {"x": (10, None)}, "name": "second"}, + ], + reduce_func=np.max, + groups_dim="group", ) # If we explicitly pass the geometry it should again be able to sanitize the atoms grouped = reduce_atom_data( atom_x, - [{"atoms": {"x": (0, 10)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], - geometry=geometry, reduce_func=np.max, groups_dim="group" + [ + {"atoms": {"x": (0, 10)}, "name": "first"}, + {"atoms": {"x": (10, None)}, "name": "second"}, + ], + geometry=geometry, + reduce_func=np.max, + groups_dim="group", ) assert isinstance(grouped, xr.DataArray) assert float(grouped.sel(group="first")) <= 10 assert float(grouped.sel(group="second")) == atom_x.max() + def test_reduce_atom_dataset_cat(atom_xyz): """We test that the atoms field is correctly sanitized using the geometry attached to the xarray object.""" grouped = reduce_atom_data( atom_xyz, - [{"atoms": {"x": (0, 10), "y": (1, 3)}, "name": "first"}, {"atoms": {"x": (10, None)}, "name": "second"}], - reduce_func=np.max, groups_dim="group" + [ + {"atoms": {"x": (0, 10), "y": (1, 3)}, "name": "first"}, + {"atoms": {"x": (10, None)}, "name": "second"}, + ], + reduce_func=np.max, + groups_dim="group", ) assert isinstance(grouped, xr.Dataset) @@ -89,4 +116,3 @@ def test_reduce_atom_dataset_cat(atom_xyz): assert float(grouped.sel(group="second").x) == atom_xyz.x.max() assert float(grouped.sel(group="first").y) <= 3 assert float(grouped.sel(group="second").y) == atom_xyz.y.max() - diff --git a/src/sisl/viz/processors/tests/test_spin.py b/src/sisl/viz/processors/tests/test_spin.py index 1c08550569..01cef227c5 100644 --- a/src/sisl/viz/processors/tests/test_spin.py +++ b/src/sisl/viz/processors/tests/test_spin.py @@ -6,7 +6,6 @@ def test_get_spin_options(): - # Unpolarized spin assert len(get_spin_options("unpolarized")) == 0 @@ -28,4 +27,3 @@ def test_get_spin_options(): options = get_spin_options("noncolinear", only_if_polarized=True) assert len(options) == 0 - diff --git a/src/sisl/viz/processors/wavefunction.py b/src/sisl/viz/processors/wavefunction.py index 407a2d0404..18d073a929 100644 --- a/src/sisl/viz/processors/wavefunction.py +++ b/src/sisl/viz/processors/wavefunction.py @@ -21,7 +21,7 @@ def get_ith_eigenstate(eigenstate: EigenstateElectron, i: int): This is useful because an EigenstateElectron contains all the eigenstates. Sometimes a post-processing tool calculates only a subset of eigenstates, and this is what you have inside the EigenstateElectron. - therefore getting eigenstate[0] does not mean that + therefore getting eigenstate[0] does not mean that Parameters ---------- @@ -29,7 +29,7 @@ def get_ith_eigenstate(eigenstate: EigenstateElectron, i: int): The object containing all eigenstates. i : int The index of the eigenstate to get. - + Returns ---------- EigenstateElectron @@ -39,22 +39,34 @@ def get_ith_eigenstate(eigenstate: EigenstateElectron, i: int): if "index" in eigenstate.info: wf_i = np.nonzero(eigenstate.info["index"] == i)[0] if len(wf_i) == 0: - raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {eigenstate.info['index']}.") + raise ValueError( + f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {eigenstate.info['index']}." + ) wf_i = wf_i[0] else: max_index = len(eigenstate) if i > max_index: - raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}].") + raise ValueError( + f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}]." + ) wf_i = i return eigenstate[wf_i] + class WavefunctionDataNode(GridDataNode): ... + @WavefunctionDataNode.register -def eigenstate_wf(eigenstate: EigenstateElectron, i: int, grid: Optional[Grid] = None, geometry: Optional[Geometry] = None, - k = [0,0,0], grid_prec: float = 0.2, spin: Optional[Spin] = None +def eigenstate_wf( + eigenstate: EigenstateElectron, + i: int, + grid: Optional[Grid] = None, + geometry: Optional[Geometry] = None, + k=[0, 0, 0], + grid_prec: float = 0.2, + spin: Optional[Spin] = None, ): if geometry is None: if isinstance(eigenstate.parent, Geometry): @@ -62,7 +74,9 @@ def eigenstate_wf(eigenstate: EigenstateElectron, i: int, grid: Optional[Grid] = else: geometry = getattr(eigenstate.parent, "geometry", None) if geometry is None: - raise ValueError('No geometry was provided and we need it the basis orbitals to build the wavefunctions from the coefficients!') + raise ValueError( + "No geometry was provided and we need it the basis orbitals to build the wavefunctions from the coefficients!" + ) if spin is None: spin = getattr(eigenstate.parent, "spin", Spin()) @@ -79,28 +93,39 @@ def eigenstate_wf(eigenstate: EigenstateElectron, i: int, grid: Optional[Grid] = wf_state = get_ith_eigenstate(eigenstate, i) # Ensure we are dealing with the R gauge - wf_state.change_gauge('R') + wf_state.change_gauge("R") # Finally, insert the wavefunction values into the grid. - wavefunction( - wf_state.state, grid, geometry=geometry, - k=k, spinor=0, spin=spin - ) + wavefunction(wf_state.state, grid, geometry=geometry, k=k, spinor=0, spin=spin) return grid @WavefunctionDataNode.register -def hamiltonian_wf(H: Hamiltonian, i: int, grid: Optional[Grid] = None, geometry: Optional[Geometry] = None, - k = [0,0,0], grid_prec: float = 0.2, spin: int = 0 +def hamiltonian_wf( + H: Hamiltonian, + i: int, + grid: Optional[Grid] = None, + geometry: Optional[Geometry] = None, + k=[0, 0, 0], + grid_prec: float = 0.2, + spin: int = 0, ): eigenstate = H.eigenstate(k=k, spin=spin) return eigenstate_wf(eigenstate, i, grid, geometry, k, grid_prec, spin) + @WavefunctionDataNode.register -def wfsx_wf(fdf, wfsx_file, i: int, grid: Optional[Grid] = None, geometry: Optional[Geometry] = None, - k = [0,0,0], grid_prec: float = 0.2, spin: int = 0 +def wfsx_wf( + fdf, + wfsx_file, + i: int, + grid: Optional[Grid] = None, + geometry: Optional[Geometry] = None, + k=[0, 0, 0], + grid_prec: float = 0.2, + spin: int = 0, ): fdf = FileDataSIESTA(path=fdf) geometry = fdf.read_geometry(output=True) @@ -121,4 +146,4 @@ def wfsx_wf(fdf, wfsx_file, i: int, grid: Optional[Grid] = None, geometry: Optio # We have not found it. raise ValueError(f"A state with k={k} was not found in file {wfsx.file}.") - return eigenstate_wf(eigenstate, i, grid, geometry, k, grid_prec) \ No newline at end of file + return eigenstate_wf(eigenstate, i, grid, geometry, k, grid_prec) diff --git a/src/sisl/viz/processors/xarray.py b/src/sisl/viz/processors/xarray.py index 0c6744a146..04f414d66c 100644 --- a/src/sisl/viz/processors/xarray.py +++ b/src/sisl/viz/processors/xarray.py @@ -24,21 +24,29 @@ def __getattr__(self, key): def __dir__(self): return dir(self._data) + class Group(TypedDict, total=False): name: str selector: Any reduce_func: Optional[Callable] ... -def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[Group], - reduce_dim: Union[str, Tuple[str, ...]], reduce_func: Union[Callable, Tuple[Callable, ...]] = np.mean, groups_dim: str = "group", - sanitize_group: Callable = lambda x: x, group_vars: Optional[Sequence[str]] = None, - drop_empty: bool = False, fill_empty: Any = 0. + +def group_reduce( + data: Union[DataArray, Dataset, XarrayData], + groups: Sequence[Group], + reduce_dim: Union[str, Tuple[str, ...]], + reduce_func: Union[Callable, Tuple[Callable, ...]] = np.mean, + groups_dim: str = "group", + sanitize_group: Callable = lambda x: x, + group_vars: Optional[Sequence[str]] = None, + drop_empty: bool = False, + fill_empty: Any = 0.0, ) -> Union[DataArray, Dataset]: """Groups contributions of orbitals into a new dimension. Given an xarray object containing orbital information and the specification of groups of orbitals, this function - computes the total contribution for each group of orbitals. It therefore removes the orbitals dimension and + computes the total contribution for each group of orbitals. It therefore removes the orbitals dimension and creates a new one to account for the groups. It can also reduce spin in the same go if requested. In that case, groups can also specify particular spin components. @@ -51,7 +59,7 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G A sequence containing the specifications for each group of orbitals. See ``Group``. reduce_func : Callable or tuple of Callable, optional The function that will compute the reduction along the reduced dimension once the selection is done. - This could be for example ``numpy.mean`` or ``numpy.sum``. + This could be for example ``numpy.mean`` or ``numpy.sum``. Notice that this will only be used in case the group specification doesn't specify a particular function in its "reduce_func" field, which will take preference. If ``reduce_dim`` is a tuple, this can also be a tuple to indicate different reducing methods for each @@ -84,7 +92,7 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G return data.drop_dims(reduce_dim) else: raise ValueError("Must specify at least one group.") - + input_is_dataarray = isinstance(data, DataArray) if not isinstance(reduce_dim, tuple): @@ -96,12 +104,14 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G for i_group, group in enumerate(groups): group = sanitize_group(group) # Get the orbitals of the group - selector = group['selector'] + selector = group["selector"] if not isinstance(selector, tuple): selector = (selector,) # Select the data we are interested in - group_vals = data.sel(**{dim: sel for dim, sel in zip(reduce_dim, selector) if sel is not None}) + group_vals = data.sel( + **{dim: sel for dim, sel in zip(reduce_dim, selector) if sel is not None} + ) empty = False for dim in reduce_dim: @@ -119,7 +129,9 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G if drop_empty: continue else: - group_vals = data.isel({dim: 0 for dim in reduce_dim}, drop=True).copy(deep=True) + group_vals = data.isel({dim: 0 for dim in reduce_dim}, drop=True).copy( + deep=True + ) if input_is_dataarray: group_vals[...] = fill_empty else: @@ -132,13 +144,15 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G if not isinstance(reduce_funcs, tuple): reduce_funcs = tuple([reduce_funcs] * len(reduce_dim)) for dim, func in zip(reduce_dim, reduce_funcs): - if func is None or (reduce_dim not in group_vals.dims and reduce_dim in group_vals.coords): + if func is None or ( + reduce_dim not in group_vals.dims + and reduce_dim in group_vals.coords + ): continue group_vals = group_vals.reduce(func, dim=dim) - # Assign the name to this group and add it to the list of groups. - name = group.get('name') or i_group + name = group.get("name") or i_group names.append(name) if input_is_dataarray: group_vals.name = name @@ -148,7 +162,7 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G if group_vars is not None: for var in group_vars: group_vars_dict[var].append(group.get(var)) - + # Concatenate all the groups into a single xarray object creating a new coordinate. new_obj = xr.concat(groups_vals, dim=groups_dim).assign_coords({groups_dim: names}) if input_is_dataarray: @@ -158,17 +172,26 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G # If there were extra group variables, then create a Dataset with them if group_vars is not None: - if isinstance(new_obj, DataArray): new_obj = new_obj.to_dataset() - - new_obj = new_obj.assign({ - k: DataArray(v, dims=[groups_dim], name=k) for k,v in group_vars_dict.items() - }) + + new_obj = new_obj.assign( + { + k: DataArray(v, dims=[groups_dim], name=k) + for k, v in group_vars_dict.items() + } + ) return new_obj -def scale_variable(dataset: Dataset, var: str, scale: float = 1, default_value: Union[float, None] = None, allow_not_present: bool = False) -> Dataset: + +def scale_variable( + dataset: Dataset, + var: str, + scale: float = 1, + default_value: Union[float, None] = None, + allow_not_present: bool = False, +) -> Dataset: new = dataset.copy() if var not in new: @@ -184,13 +207,17 @@ def scale_variable(dataset: Dataset, var: str, scale: float = 1, default_value: new[var][new[var] == None] = default_value * scale return new + def select(dataset: Dataset, dim: str, selector: Any) -> Dataset: if selector is not None: dataset = dataset.sel(**{dim: selector}) return dataset + def filter_energy_range( - data: Union[DataArray, Dataset], Erange: Optional[Tuple[float, float]]=None, E0: float = 0 + data: Union[DataArray, Dataset], + Erange: Optional[Tuple[float, float]] = None, + E0: float = 0, ) -> Union[DataArray, Dataset]: # Shift the energies E_data = data.assign_coords(E=data.E - E0) diff --git a/src/sisl/viz/types.py b/src/sisl/viz/types.py index ceaf3027ca..5484dc2e51 100644 --- a/src/sisl/viz/types.py +++ b/src/sisl/viz/types.py @@ -13,29 +13,37 @@ from sisl.lattice import Lattice, LatticeChild from sisl.typing import AtomsArgument -PathLike = Union[str, Path, BaseSile] +PathLike = Union[str, Path, BaseSile] Color = NewType("Color", str) GeometryLike = Union[sisl.Geometry, Any] -Axis = Union[Literal["x", "y", "z", "-x", "-y", "-z", "a", "b", "c", "-a", "-b", "-c"], Sequence[float]] +Axis = Union[ + Literal["x", "y", "z", "-x", "-y", "-z", "a", "b", "c", "-a", "-b", "-c"], + Sequence[float], +] Axes = Sequence[Axis] GeometryLike = Union[Geometry, PathLike] + @dataclass class StyleSpec: color: Optional[Color] = None size: Optional[float] = None opacity: Optional[float] = 1 - dash: Optional[Literal["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"]] = None + dash: Optional[ + Literal["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] + ] = None + @dataclass class AtomsStyleSpec(StyleSpec): atoms: AtomsArgument = None vertices: Optional[float] = 15 + class AtomsStyleSpecDict(TypedDict): atoms: AtomsArgument color: Optional[Color] @@ -43,11 +51,13 @@ class AtomsStyleSpecDict(TypedDict): opacity: Optional[float] vertices: Optional[float] + @dataclass class Query: active: bool = True name: str = "" + Queries = Sequence[Query] SpeciesSpec = NewType("SpeciesSpec", Optional[Sequence[str]]) @@ -55,10 +65,11 @@ class Query: OrbitalsNames = NewType("OrbitalsNames", Optional[Sequence[str]]) SpinIndex = NewType("SpinIndex", Optional[Sequence[Literal[0, 1]]]) + @dataclass class OrbitalQuery(Query): atoms: AtomsArgument = None - species : SpeciesSpec = None + species: SpeciesSpec = None orbitals: OrbitalsNames = None n: Optional[Sequence[int]] = None l: Optional[Sequence[int]] = None @@ -68,36 +79,39 @@ class OrbitalQuery(Query): reduce: Literal["mean", "sum"] = "sum" spin_reduce: Literal["mean", "sum"] = "sum" + @dataclass class OrbitalStyleQuery(StyleSpec, OrbitalQuery): ... - + + OrbitalQueries = Sequence[OrbitalQuery] OrbitalStyleQueries = Sequence[OrbitalStyleQuery] CellLike = Union[npt.NDArray[np.float_], Lattice, LatticeChild] + @dataclass class ArrowSpec: - scale: float = 1. + scale: float = 1.0 color: Any = None - width: float = 1. - opacity: float = 1. + width: float = 1.0 + opacity: float = 1.0 name: str = "arrow" annotate: bool = False arrowhead_scale: float = 0.2 arrowhead_angle: float = 20 + @dataclass class AtomArrowSpec: data: Any atoms: AtomsArgument = None - scale: float = 1. + scale: float = 1.0 color: Any = None - width: float = 1. - opacity: float = 1. + width: float = 1.0 + opacity: float = 1.0 name: str = "arrow" annotate: bool = False arrowhead_scale: float = 0.2 arrowhead_angle: float = 20 - diff --git a/src/sisl_toolbox/btd/_btd.py b/src/sisl_toolbox/btd/_btd.py index 370584d99a..9c087a11fd 100644 --- a/src/sisl_toolbox/btd/_btd.py +++ b/src/sisl_toolbox/btd/_btd.py @@ -56,7 +56,7 @@ def dagger(M): def _scat_state_svd(A, **kwargs): - """ Calculating the SVD of matrix A for the scattering state + """Calculating the SVD of matrix A for the scattering state Parameters ---------- @@ -86,11 +86,12 @@ def _scat_state_svd(A, **kwargs): driver = kwargs.get("driver", "gesvd").lower() if driver in ("arpack", "lobpcg", "sparse"): if driver == "sparse": - driver = "arpack" # scipy default + driver = "arpack" # scipy default # filter out keys for scipy.sparse.svds - svds_kwargs = {key: kwargs[key] for key in ("k", "ncv", "tol", "v0") - if key in kwargs} + svds_kwargs = { + key: kwargs[key] for key in ("k", "ncv", "tol", "v0") if key in kwargs + } # do not calculate vt svds_kwargs["return_singular_vectors"] = "u" svds_kwargs["solver"] = driver @@ -115,11 +116,11 @@ def _scat_state_svd(A, **kwargs): # DOS * # to account for the overlap matrix. For orthogonal basis sets # this DOS eigenvalue is correct. - return DOS ** 2 / (2*np.pi), A + return DOS**2 / (2 * np.pi), A class PivotSelfEnergy(si.physics.SelfEnergy): - """ Container for the self-energy object + """Container for the self-energy object This may either be a `tbtsencSileTBtrans`, a `tbtgfSileTBtrans` or a sisl.SelfEnergy objectfile """ @@ -136,21 +137,28 @@ def __init__(self, name, se, pivot=None): self._se = se if isinstance(se, si.io.tbtrans.tbtsencSileTBtrans): + def se_func(*args, **kwargs): return self._se.self_energy(self.name, *args, **kwargs) + def broad_func(*args, **kwargs): return self._se.broadening_matrix(self.name, *args, **kwargs) + else: + def se_func(*args, **kwargs): return self._se.self_energy(*args, **kwargs) + def broad_func(*args, **kwargs): return self._se.broadening_matrix(*args, **kwargs) # Store the pivoting for faster indexing if pivot is None: if not isinstance(se, si.io.tbtrans.tbtsencSileTBtrans): - raise ValueError(f"{self.__class__.__name__} must be passed a sisl.io.tbtrans.tbtsencSileTBtrans. " - "Otherwise use the DownfoldSelfEnergy method with appropriate arguments.") + raise ValueError( + f"{self.__class__.__name__} must be passed a sisl.io.tbtrans.tbtsencSileTBtrans. " + "Otherwise use the DownfoldSelfEnergy method with appropriate arguments." + ) pivot = se # Pivoting indices for the self-energy for the device region @@ -176,8 +184,8 @@ def broad_func(*args, **kwargs): # collect the pivoting indices for the downfolding pvt_btd.append(self.pvt_down[o:i, 0]) o += i - #self.pvt_btd = np.concatenate(pvt_btd).reshape(-1, 1) - #self.pvt_btd_sort = arangei(o) + # self.pvt_btd = np.concatenate(pvt_btd).reshape(-1, 1) + # self.pvt_btd_sort = arangei(o) self._se_func = se_func self._broad_func = broad_func @@ -196,13 +204,16 @@ def broadening_matrix(self, *args, **kwargs): class DownfoldSelfEnergy(PivotSelfEnergy): - - def __init__(self, name, se, pivot, Hdevice, eta_device=0, bulk=True, bloch=(1, 1, 1)): + def __init__( + self, name, se, pivot, Hdevice, eta_device=0, bulk=True, bloch=(1, 1, 1) + ): super().__init__(name, se, pivot) if np.allclose(bloch, 1): + def _bloch(func, k, *args, **kwargs): return func(*args, k=k, **kwargs) + self._bloch = _bloch else: self._bloch = si.Bloch(bloch) @@ -248,7 +259,7 @@ def _bloch(func, k, *args, **kwargs): self._data.H = PropertyDict() self._data.H.electrode = se.spgeom0 self._data.H.device = Hdevice.sub(down_atoms) - #geometry_down = self._data.H.device.geometry + # geometry_down = self._data.H.device.geometry # Now we retain the positions of the electrode orbitals in the # non pivoted structure for inserting the self-energy @@ -268,7 +279,8 @@ def __str__(self): eta = None try: eta = self._se.eta - except Exception: pass + except Exception: + pass se = str(self._se).replace("\n", "\n ") return f"{self.__class__.__name__}{{no: {len(self)}, blocks: {len(self.btd)}, eta: {eta}, eta_device: {self._eta_device},\n {se}\n}}" @@ -304,12 +316,18 @@ def _prepare(self, E, k=(0, 0, 0)): Ed = data.Ed Eb = data.Eb - data.SeH = H.device.Sk(k, dtype=np.complex128) * Ed - H.device.Hk(k, dtype=np.complex128) + data.SeH = H.device.Sk(k, dtype=np.complex128) * Ed - H.device.Hk( + k, dtype=np.complex128 + ) if data.bulk: + def hsk(k, **kwargs): # constructor for the H and S part return H.electrode.Sk(k, **kwargs) * Eb - H.electrode.Hk(k, **kwargs) - data.SeH[data.elec, data.elec.T] = self._bloch(hsk, k, format="array", dtype=np.complex128) + + data.SeH[data.elec, data.elec.T] = self._bloch( + hsk, k, format="array", dtype=np.complex128 + ) def self_energy(self, E, k=(0, 0, 0), *args, **kwargs): self._prepare(E, k) @@ -330,11 +348,14 @@ def gM(M, idx1, idx2): Mr = 0 sli = slice(cbtd[0], cbtd[1]) for b in range(1, len(self.btd)): - sli1 = slice(cbtd[b], cbtd[b+1]) - - Mr = gM(M, sli1, sli) @ solve(gM(M, sli, sli) - Mr, - gM(M, sli, sli1), - overwrite_a=True, overwrite_b=True) + sli1 = slice(cbtd[b], cbtd[b + 1]) + + Mr = gM(M, sli1, sli) @ solve( + gM(M, sli, sli) - Mr, + gM(M, sli, sli1), + overwrite_a=True, + overwrite_b=True, + ) sli = sli1 return Mr @@ -351,22 +372,28 @@ def __len__(self): return len(self._bm.blocks) def __iter__(self): - """ Loop contained indices in the BlockMatrix """ + """Loop contained indices in the BlockMatrix""" yield from self._bm._M.keys() def __delitem__(self, key): if not isinstance(key, tuple): - raise ValueError(f"{self.__class__.__name__} index deletion must be done with a tuple.") + raise ValueError( + f"{self.__class__.__name__} index deletion must be done with a tuple." + ) del self._bm._M[key] def __contains__(self, key): if not isinstance(key, tuple): - raise ValueError(f"{self.__class__.__name__} index checking must be done with a tuple.") + raise ValueError( + f"{self.__class__.__name__} index checking must be done with a tuple." + ) return key in self._bm._M def __getitem__(self, key): if not isinstance(key, tuple): - raise ValueError(f"{self.__class__.__name__} index retrieval must be done with a tuple.") + raise ValueError( + f"{self.__class__.__name__} index retrieval must be done with a tuple." + ) M = self._bm._M.get(key) if M is None: i, j = key @@ -376,15 +403,19 @@ def __getitem__(self, key): def __setitem__(self, key, M): if not isinstance(key, tuple): - raise ValueError(f"{self.__class__.__name__} index setting must be done with a tuple.") + raise ValueError( + f"{self.__class__.__name__} index setting must be done with a tuple." + ) s = (self._bm.blocks[key[0]], self._bm.blocks[key[1]]) - assert M.shape == s, f"Could not assign matrix of shape {M.shape} into matrix of shape {s}" + assert ( + M.shape == s + ), f"Could not assign matrix of shape {M.shape} into matrix of shape {s}" self._bm._M[key] = M class BlockMatrix: - """ Container class that holds a block matrix """ + """Container class that holds a block matrix""" def __init__(self, blocks): self._M = {} @@ -398,23 +429,24 @@ def toarray(self): BI = self.block_indexer nb = len(BI) # stack stuff together - return np.concatenate([ - np.concatenate([BI[i, j] for i in range(nb)], axis=0) - for j in range(nb)], axis=1) + return np.concatenate( + [np.concatenate([BI[i, j] for i in range(nb)], axis=0) for j in range(nb)], + axis=1, + ) def tobtd(self): - """ Return only the block tridiagonal part of the matrix """ + """Return only the block tridiagonal part of the matrix""" ret = self.__class__(self.blocks) sBI = self.block_indexer rBI = ret.block_indexer nb = len(sBI) for j in range(nb): - for i in range(max(0, j-1), min(j+2, nb)): + for i in range(max(0, j - 1), min(j + 2, nb)): rBI[i, j] = sBI[i, j] return ret def tobd(self): - """ Return only the block diagonal part of the matrix """ + """Return only the block diagonal part of the matrix""" ret = self.__class__(self.blocks) sBI = self.block_indexer rBI = ret.block_indexer @@ -433,7 +465,7 @@ def block_indexer(self): class DeviceGreen: - r""" Block-tri-diagonal Green function calculator + r"""Block-tri-diagonal Green function calculator This class enables the extraction and calculation of some important quantities not currently accessible in TBtrans. @@ -499,17 +531,18 @@ class DeviceGreen: # That would probably require us to use a method to retrieve # the elements which determines if it has been calculated or not. - def __init__(self, H, elecs, pivot, eta=0.): - """ Create Green function with Hamiltonian and BTD matrix elements """ + def __init__(self, H, elecs, pivot, eta=0.0): + """Create Green function with Hamiltonian and BTD matrix elements""" self.H = H # Store electrodes (for easy retrieval of the SE) # There may be no electrodes self.elecs = elecs - #self.elecs_pvt = [pivot.pivot(el.name).reshape(-1, 1) + # self.elecs_pvt = [pivot.pivot(el.name).reshape(-1, 1) # for el in elecs] - self.elecs_pvt_dev = [pivot.pivot(el.name, in_device=True).reshape(-1, 1) - for el in elecs] + self.elecs_pvt_dev = [ + pivot.pivot(el.name, in_device=True).reshape(-1, 1) for el in elecs + ] self.pvt = pivot.pivot() self.btd = pivot.btd() @@ -532,7 +565,7 @@ def __str__(self): @classmethod def from_fdf(cls, fdf, prefix="TBT", use_tbt_se=False, eta=None): - """ Return a new `DeviceGreen` using information gathered from the fdf + """Return a new `DeviceGreen` using information gathered from the fdf Parameters ---------- @@ -557,8 +590,10 @@ def from_fdf(cls, fdf, prefix="TBT", use_tbt_se=False, eta=None): if Path(f"{slabel}.{end}").is_file(): tbt = f"{slabel}.{end}" if tbt is None: - raise FileNotFoundError(f"{cls.__name__}.from_fdf could " - f"not find file {slabel}.[TBT|TBT_UP|TBT_DN].nc") + raise FileNotFoundError( + f"{cls.__name__}.from_fdf could " + f"not find file {slabel}.[TBT|TBT_UP|TBT_DN].nc" + ) tbt = si.get_sile(tbt) is_tbtrans = prefix.upper() == "TBT" @@ -566,23 +601,27 @@ def from_fdf(cls, fdf, prefix="TBT", use_tbt_se=False, eta=None): Hdev = si.get_sile(fdf.get("TBT.HS", f"{slabel}.TSHS")).read_hamiltonian() def get_line(line): - """ Parse lines in the %block constructs of fdf's """ + """Parse lines in the %block constructs of fdf's""" key, val = line.split(" ", 1) return key.lower().strip(), val.split("#", 1)[0].strip() def read_electrode(elec_prefix): - """ Parse the electrode information and return a dictionary with content """ + """Parse the electrode information and return a dictionary with content""" from sisl.unit.siesta import unit_convert + ret = PropertyDict() if is_tbtrans: + def block_get(dic, key, default=None, unit=None): ret = dic.get(f"tbt.{key}", dic.get(key, default)) if unit is None or not isinstance(ret, str): return ret ret, un = ret.split() return float(ret) * unit_convert(un, unit) + else: + def block_get(dic, key, default=None, unit=None): ret = dic.get(key, default) if unit is None or not isinstance(ret, str): @@ -617,23 +656,31 @@ def block_get(dic, key, default=None, unit=None): if Helec: Helec = si.get_sile(Helec).read_hamiltonian() else: - raise ValueError(f"{self.__class__.__name__}.from_fdf could not find " - f"electrode HS in block: {prefix} ??") + raise ValueError( + f"{self.__class__.__name__}.from_fdf could not find " + f"electrode HS in block: {prefix} ??" + ) # Get semi-infinite direction semi_inf = None for suf in ["-direction", "-dir", ""]: semi_inf = block_get(dic, f"semi-inf{suf}", semi_inf) if semi_inf is None: - raise ValueError(f"{self.__class__.__name__}.from_fdf could not find " - f"electrode semi-inf-direction in block: {prefix} ??") + raise ValueError( + f"{self.__class__.__name__}.from_fdf could not find " + f"electrode semi-inf-direction in block: {prefix} ??" + ) # convert to sisl infinite semi_inf = semi_inf.lower() - semi_inf = semi_inf[0] + {"a1": "a", "a2": "b", "a3": "c"}.get(semi_inf[1:], semi_inf[1:]) + semi_inf = semi_inf[0] + {"a1": "a", "a2": "b", "a3": "c"}.get( + semi_inf[1:], semi_inf[1:] + ) # Check that semi_inf is a recursive one! if not semi_inf in ["-a", "+a", "-b", "+b", "-c", "+c"]: - raise NotImplementedError(f"{self.__class__.__name__} does not implement other " - "self energies than the recursive one.") + raise NotImplementedError( + f"{self.__class__.__name__} does not implement other " + "self energies than the recursive one." + ) bulk = bool(block_get(dic, "bulk", bulk)) # loop for 0 @@ -641,21 +688,28 @@ def block_get(dic, key, default=None, unit=None): for suf in sufs: bloch[i] = block_get(dic, f"bloch-{suf}", bloch[i]) - bloch = [int(b) for b in block_get(dic, "bloch", f"{bloch[0]} {bloch[1]} {bloch[2]}").split()] + bloch = [ + int(b) + for b in block_get( + dic, "bloch", f"{bloch[0]} {bloch[1]} {bloch[2]}" + ).split() + ] ret.eta = block_get(dic, "eta", eta, unit="eV") # manual shift of the fermi-level - dEf = block_get(dic, "delta-Ef", 0., unit="eV") + dEf = block_get(dic, "delta-Ef", 0.0, unit="eV") # shift electronic structure here, we store it in the returned # dictionary, for information, but it shouldn't be used. Helec.shift(dEf) ret.dEf = dEf # add a fraction of the bias in the coupling elements of the # E-C region, only meaningful for - ret.V_fraction = block_get(dic, "V-fraction", 0.) - if ret.V_fraction > 0.: - warn(f"{cls.__name__}.from_fdf(electrode={elec}) found a non-zero V-fraction value. " - "This is currently not implemented.") + ret.V_fraction = block_get(dic, "V-fraction", 0.0) + if ret.V_fraction > 0.0: + warn( + f"{cls.__name__}.from_fdf(electrode={elec}) found a non-zero V-fraction value. " + "This is currently not implemented." + ) ret.Helec = Helec ret.bloch = bloch ret.semi_inf = semi_inf @@ -681,15 +735,19 @@ def block_get(dic, key, default=None, unit=None): # read from the TBT file (to check if the user has changed the input file) elec_eta = tbt.eta(elec) if not np.allclose(elec_eta, data.eta): - warn(f"{cls.__name__}.from_fdf(electrode={elec}) found inconsistent " - f"imaginary eta from the fdf vs. TBT output, will use fdf value.\n" - f" {tbt} = {eta} eV\n {fdf} = {data.eta} eV") + warn( + f"{cls.__name__}.from_fdf(electrode={elec}) found inconsistent " + f"imaginary eta from the fdf vs. TBT output, will use fdf value.\n" + f" {tbt} = {eta} eV\n {fdf} = {data.eta} eV" + ) bloch = tbt.bloch(elec) if not np.allclose(bloch, data.bloch): - warn(f"{cls.__name__}.from_fdf(electrode={elec}) found inconsistent " - f"Bloch expansions from the fdf vs. TBT output, will use fdf value.\n" - f" {tbt} = {bloch}\n {fdf} = {data.bloch}") + warn( + f"{cls.__name__}.from_fdf(electrode={elec}) found inconsistent " + f"Bloch expansions from the fdf vs. TBT output, will use fdf value.\n" + f" {tbt} = {bloch}\n {fdf} = {data.bloch}" + ) eta_dev = min(data.eta, eta_dev) @@ -710,9 +768,11 @@ def block_get(dic, key, default=None, unit=None): eta_dev = eta elif not np.allclose(eta_dev, eta_dev_tbt): - warn(f"{cls.__name__}.from_fdf found inconsistent " - f"imaginary eta from the fdf vs. TBT output, will use fdf value.\n" - f" {tbt} = {eta_dev_tbt} eV\n {fdf} = {eta_dev} eV") + warn( + f"{cls.__name__}.from_fdf found inconsistent " + f"imaginary eta from the fdf vs. TBT output, will use fdf value.\n" + f" {tbt} = {eta_dev_tbt} eV\n {fdf} = {eta_dev} eV" + ) elecs = [] for elec in tbt.elecs: @@ -723,9 +783,11 @@ def block_get(dic, key, default=None, unit=None): if Path(f"{slabel}.TBT.SE.nc").is_file(): tbtse = si.get_sile(f"{slabel}.TBT.SE.nc") else: - raise FileNotFoundError(f"{cls.__name__}.from_fdf " - f"could not find file {slabel}.TBT.SE.nc " - "but it was requested by 'use_tbt_se'!") + raise FileNotFoundError( + f"{cls.__name__}.from_fdf " + f"could not find file {slabel}.TBT.SE.nc " + "but it was requested by 'use_tbt_se'!" + ) # shift according to potential data.Helec.shift(mu) @@ -741,23 +803,29 @@ def block_get(dic, key, default=None, unit=None): if elec in use_tbt_se: elec_se = PivotSelfEnergy(elec, tbtse) else: - elec_se = DownfoldSelfEnergy(elec, se, tbt, Hdev, - eta_device=eta_dev, - bulk=data.bulk, bloch=data.bloch) + elec_se = DownfoldSelfEnergy( + elec, + se, + tbt, + Hdev, + eta_device=eta_dev, + bulk=data.bulk, + bloch=data.bloch, + ) elecs.append(elec_se) return cls(Hdev, elecs, tbt, eta_dev) def reset(self): - """ Clean any memory used by this object """ + """Clean any memory used by this object""" self._data = PropertyDict() def __len__(self): return len(self.pvt) def _elec(self, elec): - """ Convert a string electrode to the proper linear index """ + """Convert a string electrode to the proper linear index""" if isinstance(elec, str): for iel, el in enumerate(self.elecs): if el.name == elec: @@ -767,7 +835,7 @@ def _elec(self, elec): return elec def _elec_name(self, elec): - """ Convert an electrode index or str to the name of the electrode """ + """Convert an electrode index or str to the name of the electrode""" if isinstance(elec, str): return elec elif isinstance(elec, PivotSelfEnergy): @@ -842,7 +910,9 @@ def _prepare(self, E, k): k = data.k # Prepare the Green function calculation - inv_G = self.H.Sk(k, dtype=np.complex128) * Ec - self.H.Hk(k, dtype=np.complex128) + inv_G = self.H.Sk(k, dtype=np.complex128) * Ec - self.H.Hk( + k, dtype=np.complex128 + ) # Now reduce the sparse matrix to the device region (plus do the pivoting) inv_G = inv_G[self.pvt, :][:, self.pvt] @@ -878,16 +948,16 @@ def _prepare(self, E, k): # rotate slices sln = sl0 sl0 = slp - slp = slice(cbtd[b+1], cbtd[b+2]) + slp = slice(cbtd[b + 1], cbtd[b + 2]) iG = inv_G[sl0, :].tocsc() - B[b-1] = iG[:, sln].toarray() + B[b - 1] = iG[:, sln].toarray() A[b] = iG[:, sl0].toarray() - C[b+1] = iG[:, slp].toarray() + C[b + 1] = iG[:, slp].toarray() # and final matrix A and B iG = inv_G[slp, :].tocsc() A[nbm1] = iG[:, slp].toarray() - B[nbm1-1] = iG[:, sl0].toarray() + B[nbm1 - 1] = iG[:, sl0].toarray() # clean-up, not used anymore del inv_G @@ -906,9 +976,9 @@ def _prepare(self, E, k): for n in range(2, nb): p = nb - n - 1 # \tilde Y - tY[n] = solve(A[n-1] - B[n-2] @ tY[n-1], C[n], overwrite_a=True) + tY[n] = solve(A[n - 1] - B[n - 2] @ tY[n - 1], C[n], overwrite_a=True) # \tilde X - tX[p] = solve(A[p+1] - C[p+2] @ tX[p+1], B[p], overwrite_a=True) + tX[p] = solve(A[p + 1] - C[p + 2] @ tX[p + 1], B[p], overwrite_a=True) data.tX = tX data.tY = tY @@ -920,12 +990,12 @@ def _matrix_to_btd(self, M): nb = len(BI) if ssp.issparse(M): for jb in range(nb): - for ib in range(max(0, jb-1), min(jb+2, nb)): - BI[ib, jb] = M[c[ib]:c[ib+1], c[jb]:c[jb+1]].toarray() + for ib in range(max(0, jb - 1), min(jb + 2, nb)): + BI[ib, jb] = M[c[ib] : c[ib + 1], c[jb] : c[jb + 1]].toarray() else: for jb in range(nb): - for ib in range(max(0, jb-1), min(jb+2, nb)): - BI[ib, jb] = M[c[ib]:c[ib+1], c[jb]:c[jb+1]] + for ib in range(max(0, jb - 1), min(jb + 2, nb)): + BI[ib, jb] = M[c[ib] : c[ib + 1], c[jb] : c[jb + 1]] return BM def Sk(self, *args, **kwargs): @@ -961,11 +1031,11 @@ def _get_blocks(self, idx): if block1 == block2: blocks = [block1] else: - blocks = [b for b in range(block1, block2+1)] + blocks = [b for b in range(block1, block2 + 1)] return blocks def green(self, E, k=(0, 0, 0), format="array"): - r""" Calculate the Green function for a given `E` and `k` point + r"""Calculate the Green function for a given `E` and `k` point The Green function is calculated as: @@ -994,7 +1064,9 @@ def green(self, E, k=(0, 0, 0), format="array"): format = "array" func = getattr(self, f"_green_{format}", None) if func is None: - raise ValueError(f"{self.__class__.__name__}.green format not valid input [array|sparse|bm|btd|bd]") + raise ValueError( + f"{self.__class__.__name__}.green format not valid input [array|sparse|bm|btd|bd]" + ) return func() def _green_array(self): @@ -1027,7 +1099,7 @@ def _green_array(self): for a in range(b - 1, -1, -1): # Calculate all parts above sla = slice(next_sum - btd[a], next_sum) - G[sla, sl0] = - tY[a + 1] @ G[slp, sl0] + G[sla, sl0] = -tY[a + 1] @ G[slp, sl0] slp = sla next_sum -= btd[a] @@ -1042,7 +1114,7 @@ def _green_array(self): for a in range(b + 1, nb): # Calculate all parts above sla = slice(next_sum, next_sum + btd[a]) - G[sla, sl0] = - tX[a - 1] @ G[slp, sl0] + G[sla, sl0] = -tX[a - 1] @ G[slp, sl0] slp = sla next_sum += btd[a] @@ -1070,10 +1142,10 @@ def _green_btd(self): BI[b, b] = G11 # do above if b > 0: - BI[b - 1, b] = - tY[b] @ G11 + BI[b - 1, b] = -tY[b] @ G11 # do below if b < nbm1: - BI[b + 1, b] = - tX[b] @ G11 + BI[b + 1, b] = -tX[b] @ G11 return G @@ -1088,12 +1160,12 @@ def _green_bm(self): for b in range(nb): G0 = BI[b, b] for bb in range(b, 0, -1): - G0 = - tY[bb] @ G0 - BI[bb-1, b] = G0 + G0 = -tY[bb] @ G0 + BI[bb - 1, b] = G0 G0 = BI[b, b] for bb in range(b, nbm1): - G0 = - tX[bb] @ G0 - BI[bb+1, b] = G0 + G0 = -tX[bb] @ G0 + BI[bb + 1, b] = G0 return G @@ -1187,9 +1259,13 @@ def _green_diag_block(self, idx): # Find parts we need to calculate blocks = self._get_blocks(idx) - assert len(blocks) <= 2, f"{self.__class__.__name__} green(diagonal) requires maximally 2 blocks" + assert ( + len(blocks) <= 2 + ), f"{self.__class__.__name__} green(diagonal) requires maximally 2 blocks" if len(blocks) == 2: - assert blocks[0]+1 == blocks[1], f"{self.__class__.__name__} green(diagonal) requires spanning only 2 blocks" + assert ( + blocks[0] + 1 == blocks[1] + ), f"{self.__class__.__name__} green(diagonal) requires spanning only 2 blocks" n = self.btd[blocks].sum() G = np.empty([n, len(idx)], dtype=self._data.A[0].dtype) @@ -1220,7 +1296,9 @@ def _green_diag_block(self, idx): elif b == nbm1: G[sl, c_idx] = inv_destroy(A[b] - B[b - 1] @ tY[b])[:, b_idx] else: - G[sl, c_idx] = inv_destroy(A[b] - B[b - 1] @ tY[b] - C[b + 1] @ tX[b])[:, b_idx] + G[sl, c_idx] = inv_destroy(A[b] - B[b - 1] @ tY[b] - C[b + 1] @ tX[b])[ + :, b_idx + ] if len(blocks) == 1: break @@ -1229,11 +1307,11 @@ def _green_diag_block(self, idx): if b == blocks[0]: # Calculate below slp = slice(btd[b], btd[b] + btd[blocks[1]]) - G[slp, c_idx] = - tX[b] @ G[sl, c_idx] + G[slp, c_idx] = -tX[b] @ G[sl, c_idx] else: # Calculate above slp = slice(0, btd[blocks[0]]) - G[slp, c_idx] = - tY[b] @ G[sl, c_idx] + G[slp, c_idx] = -tY[b] @ G[sl, c_idx] return blocks, G @@ -1245,9 +1323,13 @@ def _green_column(self, idx): # Find parts we need to calculate blocks = self._get_blocks(idx) - assert len(blocks) <= 2, f"{self.__class__.__name__}.green(column) requires maximally 2 blocks" + assert ( + len(blocks) <= 2 + ), f"{self.__class__.__name__}.green(column) requires maximally 2 blocks" if len(blocks) == 2: - assert blocks[0]+1 == blocks[1], f"{self.__class__.__name__}.green(column) requires spanning only 2 blocks" + assert ( + blocks[0] + 1 == blocks[1] + ), f"{self.__class__.__name__}.green(column) requires spanning only 2 blocks" n = len(self) G = np.empty([n, len(idx)], dtype=self._data.A[0].dtype) @@ -1276,7 +1358,9 @@ def _green_column(self, idx): elif b == nbm1: G[sl, c_idx] = inv_destroy(A[b] - B[b - 1] @ tY[b])[:, b_idx] else: - G[sl, c_idx] = inv_destroy(A[b] - B[b - 1] @ tY[b] - C[b + 1] @ tX[b])[:, b_idx] + G[sl, c_idx] = inv_destroy(A[b] - B[b - 1] @ tY[b] - C[b + 1] @ tX[b])[ + :, b_idx + ] if len(blocks) == 1: break @@ -1287,18 +1371,18 @@ def _green_column(self, idx): if b == blocks[0] and b < nb - 1: # Calculate below slp = slice(c[b + 1], c[b + 2]) - G[slp, c_idx] = - tX[b] @ G[sl, c_idx] + G[slp, c_idx] = -tX[b] @ G[sl, c_idx] elif b > 0: # Calculate above slp = slice(c[b - 1], c[b]) - G[slp, c_idx] = - tY[b] @ G[sl, c_idx] + G[slp, c_idx] = -tY[b] @ G[sl, c_idx] # Now we can calculate the Gf column above b = blocks[0] slp = slice(c[b], c[b + 1]) for b in range(blocks[0] - 1, -1, -1): sl = slice(c[b], c[b + 1]) - G[sl, :] = - tY[b + 1] @ G[slp, :] + G[sl, :] = -tY[b + 1] @ G[slp, :] slp = sl # All blocks below @@ -1306,13 +1390,15 @@ def _green_column(self, idx): slp = slice(c[b], c[b + 1]) for b in range(blocks[-1] + 1, nb): sl = slice(c[b], c[b + 1]) - G[sl, :] = - tX[b - 1] @ G[slp, :] + G[sl, :] = -tX[b - 1] @ G[slp, :] slp = sl return G - def spectral(self, elec, E, k=(0, 0, 0), format="array", method="column", herm=True): - r""" Calculate the spectral function for a given `E` and `k` point from a given electrode + def spectral( + self, elec, E, k=(0, 0, 0), format="array", method="column", herm=True + ): + r"""Calculate the spectral function for a given `E` and `k` point from a given electrode The spectral function is calculated as: @@ -1356,7 +1442,9 @@ def spectral(self, elec, E, k=(0, 0, 0), format="array", method="column", herm=T format = "btd" func = getattr(self, f"_spectral_{method}_{format}", None) if func is None: - raise ValueError(f"{self.__class__.__name__}.spectral combination of format+method not recognized {format}+{method}.") + raise ValueError( + f"{self.__class__.__name__}.spectral combination of format+method not recognized {format}+{method}." + ) return func(elec, herm) def _spectral_column_array(self, elec, herm): @@ -1379,10 +1467,10 @@ def _spectral_column_bm(self, elec, herm): if herm: # loop columns for jb in range(nb): - slj = slice(c[jb], c[jb+1]) + slj = slice(c[jb], c[jb + 1]) Gj = Gam @ dagger(G[slj, :]) for ib in range(jb): - sli = slice(c[ib], c[ib+1]) + sli = slice(c[ib], c[ib + 1]) BI[ib, jb] = G[sli, :] @ Gj BI[jb, ib] = BI[ib, jb].T.conj() BI[jb, jb] = G[slj, :] @ Gj @@ -1390,10 +1478,10 @@ def _spectral_column_bm(self, elec, herm): else: # loop columns for jb in range(nb): - slj = slice(c[jb], c[jb+1]) + slj = slice(c[jb], c[jb + 1]) Gj = Gam @ dagger(G[slj, :]) for ib in range(nb): - sli = slice(c[ib], c[ib+1]) + sli = slice(c[ib], c[ib + 1]) BI[ib, jb] = G[sli, :] @ Gj return btd @@ -1412,10 +1500,10 @@ def _spectral_column_btd(self, elec, herm): if herm: # loop columns for jb in range(nb): - slj = slice(c[jb], c[jb+1]) + slj = slice(c[jb], c[jb + 1]) Gj = Gam @ dagger(G[slj, :]) for ib in range(max(0, jb - 1), jb): - sli = slice(c[ib], c[ib+1]) + sli = slice(c[ib], c[ib + 1]) BI[ib, jb] = G[sli, :] @ Gj BI[jb, ib] = BI[ib, jb].T.conj() BI[jb, jb] = G[slj, :] @ Gj @@ -1423,10 +1511,10 @@ def _spectral_column_btd(self, elec, herm): else: # loop columns for jb in range(nb): - slj = slice(c[jb], c[jb+1]) + slj = slice(c[jb], c[jb + 1]) Gj = Gam @ dagger(G[slj, :]) - for ib in range(max(0, jb-1), min(jb+2, nb)): - sli = slice(c[ib], c[ib+1]) + for ib in range(max(0, jb - 1), min(jb + 2, nb)): + sli = slice(c[ib], c[ib + 1]) BI[ib, jb] = G[sli, :] @ Gj return btd @@ -1444,7 +1532,7 @@ def _spectral_propagate_array(self, elec, herm): S = np.empty([len(self), len(self)], dtype=A.dtype) c = self.btd_cum0 - S[c[blocks[0]]:c[blocks[-1]+1], c[blocks[0]]:c[blocks[-1]+1]] = A + S[c[blocks[0]] : c[blocks[-1] + 1], c[blocks[0]] : c[blocks[-1] + 1]] = A del A # now loop backwards @@ -1452,73 +1540,77 @@ def _spectral_propagate_array(self, elec, herm): tY = self._data.tY def gs(ib, jb): - return slice(c[ib], c[ib+1]), slice(c[jb], c[jb+1]) + return slice(c[ib], c[ib + 1]), slice(c[jb], c[jb + 1]) if herm: # above left for jb in range(blocks[0], -1, -1): for ib in range(jb, 0, -1): - A = - tY[ib] @ S[gs(ib, jb)] - S[gs(ib-1, jb)] = A - S[gs(jb, ib-1)] = A.T.conj() + A = -tY[ib] @ S[gs(ib, jb)] + S[gs(ib - 1, jb)] = A + S[gs(jb, ib - 1)] = A.T.conj() # calculate next diagonal if jb > 0: - S[gs(jb-1, jb-1)] = - S[gs(jb-1, jb)] @ dagger(tY[jb]) + S[gs(jb - 1, jb - 1)] = -S[gs(jb - 1, jb)] @ dagger(tY[jb]) if nblocks == 2: # above for ib in range(blocks[1], 1, -1): - A = - tY[ib-1] @ S[gs(ib-1, blocks[1])] - S[gs(ib-2, blocks[1])] = A - S[gs(blocks[1], ib-2)] = A.T.conj() + A = -tY[ib - 1] @ S[gs(ib - 1, blocks[1])] + S[gs(ib - 2, blocks[1])] = A + S[gs(blocks[1], ib - 2)] = A.T.conj() # below - for ib in range(blocks[0], nbm1-1): - A = - tX[ib+1] @ S[gs(ib+1, blocks[0])] - S[gs(ib+2, blocks[0])] = A - S[gs(blocks[0], ib+2)] = A.T.conj() + for ib in range(blocks[0], nbm1 - 1): + A = -tX[ib + 1] @ S[gs(ib + 1, blocks[0])] + S[gs(ib + 2, blocks[0])] = A + S[gs(blocks[0], ib + 2)] = A.T.conj() # below right for jb in range(blocks[-1], nb): for ib in range(jb, nbm1): - A = - tX[ib] @ S[gs(ib, jb)] - S[gs(ib+1, jb)] = A - S[gs(jb, ib+1)] = A.T.conj() + A = -tX[ib] @ S[gs(ib, jb)] + S[gs(ib + 1, jb)] = A + S[gs(jb, ib + 1)] = A.T.conj() # calculate next diagonal if jb < nbm1: - S[gs(jb+1, jb+1)] = - S[gs(jb+1, jb)] @ dagger(tX[jb]) + S[gs(jb + 1, jb + 1)] = -S[gs(jb + 1, jb)] @ dagger(tX[jb]) else: for jb in range(blocks[0], -1, -1): # above for ib in range(jb, 0, -1): - S[gs(ib-1, jb)] = - tY[ib] @ S[gs(ib, jb)] + S[gs(ib - 1, jb)] = -tY[ib] @ S[gs(ib, jb)] # calculate next diagonal if jb > 0: - S[gs(jb-1, jb-1)] = - S[gs(jb-1, jb)] @ dagger(tY[jb]) + S[gs(jb - 1, jb - 1)] = -S[gs(jb - 1, jb)] @ dagger(tY[jb]) # left for ib in range(jb, 0, -1): - S[gs(jb, ib-1)] = - S[gs(jb, ib)] @ dagger(tY[ib]) + S[gs(jb, ib - 1)] = -S[gs(jb, ib)] @ dagger(tY[ib]) if nblocks == 2: # above and left for ib in range(blocks[1], 1, -1): - S[gs(ib-2, blocks[1])] = - tY[ib-1] @ S[gs(ib-1, blocks[1])] - S[gs(blocks[1], ib-2)] = - S[gs(blocks[1], ib-1)] @ dagger(tY[ib-1]) + S[gs(ib - 2, blocks[1])] = -tY[ib - 1] @ S[gs(ib - 1, blocks[1])] + S[gs(blocks[1], ib - 2)] = -S[gs(blocks[1], ib - 1)] @ dagger( + tY[ib - 1] + ) # below and right - for ib in range(blocks[0], nbm1-1): - S[gs(ib+2, blocks[0])] = - tX[ib+1] @ S[gs(ib+1, blocks[0])] - S[gs(blocks[0], ib+2)] = - S[gs(blocks[0], ib+1)] @ dagger(tX[ib+1]) + for ib in range(blocks[0], nbm1 - 1): + S[gs(ib + 2, blocks[0])] = -tX[ib + 1] @ S[gs(ib + 1, blocks[0])] + S[gs(blocks[0], ib + 2)] = -S[gs(blocks[0], ib + 1)] @ dagger( + tX[ib + 1] + ) # below right for jb in range(blocks[-1], nb): for ib in range(jb, nbm1): - S[gs(ib+1, jb)] = - tX[ib] @ S[gs(ib, jb)] + S[gs(ib + 1, jb)] = -tX[ib] @ S[gs(ib, jb)] # calculate next diagonal if jb < nbm1: - S[gs(jb+1, jb+1)] = - S[gs(jb+1, jb)] @ dagger(tX[jb]) + S[gs(jb + 1, jb + 1)] = -S[gs(jb + 1, jb)] @ dagger(tX[jb]) # right for ib in range(jb, nbm1): - S[gs(jb, ib+1)] = - S[gs(jb, ib)] @ dagger(tX[ib]) + S[gs(jb, ib + 1)] = -S[gs(jb, ib)] @ dagger(tX[ib]) return S @@ -1535,11 +1627,11 @@ def _spectral_propagate_bm(self, elec, herm): nblocks = len(blocks) A = A @ self._data.gamma[elec] @ dagger(A) - BI[blocks[0], blocks[0]] = A[:btd[blocks[0]], :btd[blocks[0]]] + BI[blocks[0], blocks[0]] = A[: btd[blocks[0]], : btd[blocks[0]]] if len(blocks) > 1: - BI[blocks[0], blocks[1]] = A[:btd[blocks[0]], btd[blocks[0]]:] - BI[blocks[1], blocks[0]] = A[btd[blocks[0]]:, :btd[blocks[0]]] - BI[blocks[1], blocks[1]] = A[btd[blocks[0]]:, btd[blocks[0]]:] + BI[blocks[0], blocks[1]] = A[: btd[blocks[0]], btd[blocks[0]] :] + BI[blocks[1], blocks[0]] = A[btd[blocks[0]] :, : btd[blocks[0]]] + BI[blocks[1], blocks[1]] = A[btd[blocks[0]] :, btd[blocks[0]] :] # now loop backwards tX = self._data.tX @@ -1549,67 +1641,67 @@ def _spectral_propagate_bm(self, elec, herm): # above left for jb in range(blocks[0], -1, -1): for ib in range(jb, 0, -1): - A = - tY[ib] @ BI[ib, jb] - BI[ib-1, jb] = A - BI[jb, ib-1] = A.T.conj() + A = -tY[ib] @ BI[ib, jb] + BI[ib - 1, jb] = A + BI[jb, ib - 1] = A.T.conj() # calculate next diagonal if jb > 0: - BI[jb-1, jb-1] = - BI[jb-1, jb] @ dagger(tY[jb]) + BI[jb - 1, jb - 1] = -BI[jb - 1, jb] @ dagger(tY[jb]) if nblocks == 2: # above for ib in range(blocks[1], 1, -1): - A = - tY[ib-1] @ BI[ib-1, blocks[1]] - BI[ib-2, blocks[1]] = A - BI[blocks[1], ib-2] = A.T.conj() + A = -tY[ib - 1] @ BI[ib - 1, blocks[1]] + BI[ib - 2, blocks[1]] = A + BI[blocks[1], ib - 2] = A.T.conj() # below - for ib in range(blocks[0], nbm1-1): - A = - tX[ib+1] @ BI[ib+1, blocks[0]] - BI[ib+2, blocks[0]] = A - BI[blocks[0], ib+2] = A.T.conj() + for ib in range(blocks[0], nbm1 - 1): + A = -tX[ib + 1] @ BI[ib + 1, blocks[0]] + BI[ib + 2, blocks[0]] = A + BI[blocks[0], ib + 2] = A.T.conj() # below right for jb in range(blocks[-1], nb): for ib in range(jb, nbm1): - A = - tX[ib] @ BI[ib, jb] - BI[ib+1, jb] = A - BI[jb, ib+1] = A.T.conj() + A = -tX[ib] @ BI[ib, jb] + BI[ib + 1, jb] = A + BI[jb, ib + 1] = A.T.conj() # calculate next diagonal if jb < nbm1: - BI[jb+1, jb+1] = - BI[jb+1, jb] @ dagger(tX[jb]) + BI[jb + 1, jb + 1] = -BI[jb + 1, jb] @ dagger(tX[jb]) else: for jb in range(blocks[0], -1, -1): # above for ib in range(jb, 0, -1): - BI[ib-1, jb] = - tY[ib] @ BI[ib, jb] + BI[ib - 1, jb] = -tY[ib] @ BI[ib, jb] # calculate next diagonal if jb > 0: - BI[jb-1, jb-1] = - BI[jb-1, jb] @ dagger(tY[jb]) + BI[jb - 1, jb - 1] = -BI[jb - 1, jb] @ dagger(tY[jb]) # left for ib in range(jb, 0, -1): - BI[jb, ib-1] = - BI[jb, ib] @ dagger(tY[ib]) + BI[jb, ib - 1] = -BI[jb, ib] @ dagger(tY[ib]) if nblocks == 2: # above and left for ib in range(blocks[1], 1, -1): - BI[ib-2, blocks[1]] = - tY[ib-1] @ BI[ib-1, blocks[1]] - BI[blocks[1], ib-2] = - BI[blocks[1], ib-1] @ dagger(tY[ib-1]) + BI[ib - 2, blocks[1]] = -tY[ib - 1] @ BI[ib - 1, blocks[1]] + BI[blocks[1], ib - 2] = -BI[blocks[1], ib - 1] @ dagger(tY[ib - 1]) # below and right - for ib in range(blocks[0], nbm1-1): - BI[ib+2, blocks[0]] = - tX[ib+1] @ BI[ib+1, blocks[0]] - BI[blocks[0], ib+2] = - BI[blocks[0], ib+1] @ dagger(tX[ib+1]) + for ib in range(blocks[0], nbm1 - 1): + BI[ib + 2, blocks[0]] = -tX[ib + 1] @ BI[ib + 1, blocks[0]] + BI[blocks[0], ib + 2] = -BI[blocks[0], ib + 1] @ dagger(tX[ib + 1]) # below right for jb in range(blocks[-1], nb): for ib in range(jb, nbm1): - BI[ib+1, jb] = - tX[ib] @ BI[ib, jb] + BI[ib + 1, jb] = -tX[ib] @ BI[ib, jb] # calculate next diagonal if jb < nbm1: - BI[jb+1, jb+1] = - BI[jb+1, jb] @ dagger(tX[jb]) + BI[jb + 1, jb + 1] = -BI[jb + 1, jb] @ dagger(tX[jb]) # right for ib in range(jb, nbm1): - BI[jb, ib+1] = - BI[jb, ib] @ dagger(tX[ib]) + BI[jb, ib + 1] = -BI[jb, ib] @ dagger(tX[ib]) return BM @@ -1625,11 +1717,11 @@ def _spectral_propagate_btd(self, elec, herm): blocks, A = self._green_diag_block(self.elecs_pvt_dev[elec].ravel()) A = A @ self._data.gamma[elec] @ dagger(A) - BI[blocks[0], blocks[0]] = A[:btd[blocks[0]], :btd[blocks[0]]] + BI[blocks[0], blocks[0]] = A[: btd[blocks[0]], : btd[blocks[0]]] if len(blocks) > 1: - BI[blocks[0], blocks[1]] = A[:btd[blocks[0]], btd[blocks[0]]:] - BI[blocks[1], blocks[0]] = A[btd[blocks[0]]:, :btd[blocks[0]]] - BI[blocks[1], blocks[1]] = A[btd[blocks[0]]:, btd[blocks[0]]:] + BI[blocks[0], blocks[1]] = A[: btd[blocks[0]], btd[blocks[0]] :] + BI[blocks[1], blocks[0]] = A[btd[blocks[0]] :, : btd[blocks[0]]] + BI[blocks[1], blocks[1]] = A[btd[blocks[0]] :, btd[blocks[0]] :] # now loop backwards tX = self._data.tX @@ -1638,36 +1730,36 @@ def _spectral_propagate_btd(self, elec, herm): if herm: # above for b in range(blocks[0], 0, -1): - A = - tY[b] @ BI[b, b] - BI[b-1, b] = A - BI[b-1, b-1] = - A @ dagger(tY[b]) - BI[b, b-1] = A.T.conj() + A = -tY[b] @ BI[b, b] + BI[b - 1, b] = A + BI[b - 1, b - 1] = -A @ dagger(tY[b]) + BI[b, b - 1] = A.T.conj() # right for b in range(blocks[-1], nbm1): - A = - BI[b, b] @ dagger(tX[b]) - BI[b, b+1] = A - BI[b+1, b+1] = - tX[b] @ A - BI[b+1, b] = A.T.conj() + A = -BI[b, b] @ dagger(tX[b]) + BI[b, b + 1] = A + BI[b + 1, b + 1] = -tX[b] @ A + BI[b + 1, b] = A.T.conj() else: # above for b in range(blocks[0], 0, -1): dtY = dagger(tY[b]) - A = - tY[b] @ BI[b, b] - BI[b-1, b] = A - BI[b-1, b-1] = - A @ dtY - BI[b, b-1] = - BI[b, b] @ dtY + A = -tY[b] @ BI[b, b] + BI[b - 1, b] = A + BI[b - 1, b - 1] = -A @ dtY + BI[b, b - 1] = -BI[b, b] @ dtY # right for b in range(blocks[-1], nbm1): - A = - BI[b, b] @ dagger(tX[b]) - BI[b, b+1] = A - BI[b+1, b+1] = - tX[b] @ A - BI[b+1, b] = - tX[b] @ BI[b, b] + A = -BI[b, b] @ dagger(tX[b]) + BI[b, b + 1] = A + BI[b + 1, b + 1] = -tX[b] @ A + BI[b + 1, b] = -tX[b] @ BI[b, b] return BM def _scattering_state_reduce(self, elec, DOS, U, cutoff): - """ U on input is a fortran-index as returned from eigh or svd """ + """U on input is a fortran-index as returned from eigh or svd""" # Select only the first N components where N is the # number of orbitals in the electrode (there can't be # any more propagating states anyhow). @@ -1688,8 +1780,10 @@ def _scattering_state_reduce(self, elec, DOS, U, cutoff): return DOS[idx], U[:, idx] - def scattering_state(self, elec, E, k=(0, 0, 0), cutoff=0., method="svd:gamma", *args, **kwargs): - r""" Calculate the scattering states for a given `E` and `k` point from a given electrode + def scattering_state( + self, elec, E, k=(0, 0, 0), cutoff=0.0, method="svd:gamma", *args, **kwargs + ): + r"""Calculate the scattering states for a given `E` and `k` point from a given electrode The scattering states are the eigen states of the spectral function: @@ -1743,10 +1837,12 @@ def scattering_state(self, elec, E, k=(0, 0, 0), cutoff=0., method="svd:gamma", method = method.lower().replace(":", "_") func = getattr(self, f"_scattering_state_{method}", None) if func is None: - raise ValueError(f"{self.__class__.__name__}.scattering_state method is not [full,svd,propagate]") + raise ValueError( + f"{self.__class__.__name__}.scattering_state method is not [full,svd,propagate]" + ) return func(elec, cutoff, *args, **kwargs) - def _scattering_state_full(self, elec, cutoff=0., **kwargs): + def _scattering_state_full(self, elec, cutoff=0.0, **kwargs): # We know that scattering_state has called prepare! A = self.spectral(elec, self._data.E, self._data.k, **kwargs) @@ -1763,17 +1859,13 @@ def _scattering_state_full(self, elec, cutoff=0., **kwargs): data = self._data info = dict( - method="full", - elec=self._elec_name(elec), - E=data.E, - k=data.k, - cutoff=cutoff + method="full", elec=self._elec_name(elec), E=data.E, k=data.k, cutoff=cutoff ) # always have the first state with the largest values return si.physics.StateCElectron(A.T, DOS, self, **info) - def _scattering_state_svd_gamma(self, elec, cutoff=0., **kwargs): + def _scattering_state_svd_gamma(self, elec, cutoff=0.0, **kwargs): A = self._green_column(self.elecs_pvt_dev[elec].ravel()) # This calculation uses the cholesky decomposition of Gamma @@ -1790,7 +1882,7 @@ def _scattering_state_svd_gamma(self, elec, cutoff=0., **kwargs): elec=self._elec_name(elec), E=data.E, k=data.k, - cutoff=cutoff + cutoff=cutoff, ) # always have the first state with the largest values @@ -1844,26 +1936,26 @@ def _scattering_state_svd_a(self, elec, cutoff=0, **kwargs): # Create full U A = np.empty([len(self), u.shape[1]], dtype=u.dtype) - sl = slice(cbtd[blocks[0]], cbtd[blocks[0]+1]) - A[sl, :] = u[:self.btd[blocks[0]], :] + sl = slice(cbtd[blocks[0]], cbtd[blocks[0] + 1]) + A[sl, :] = u[: self.btd[blocks[0]], :] if len(blocks) > 1: - sl = slice(cbtd[blocks[1]], cbtd[blocks[1]+1]) - A[sl, :] = u[self.btd[blocks[0]]:, :] + sl = slice(cbtd[blocks[1]], cbtd[blocks[1] + 1]) + A[sl, :] = u[self.btd[blocks[0]] :, :] del u # Propagate A in the full BTD matrix t = self._data.tY - sl = slice(cbtd[blocks[0]], cbtd[blocks[0]+1]) + sl = slice(cbtd[blocks[0]], cbtd[blocks[0] + 1]) for b in range(blocks[0], 0, -1): - sln = slice(cbtd[b-1], cbtd[b]) - A[sln] = - t[b] @ A[sl] + sln = slice(cbtd[b - 1], cbtd[b]) + A[sln] = -t[b] @ A[sl] sl = sln t = self._data.tX - sl = slice(cbtd[blocks[-1]], cbtd[blocks[-1]+1]) + sl = slice(cbtd[blocks[-1]], cbtd[blocks[-1] + 1]) for b in range(blocks[-1], nb - 1): - slp = slice(cbtd[b+1], cbtd[b+2]) - A[slp] = - t[b] @ A[sl] + slp = slice(cbtd[b + 1], cbtd[b + 2]) + A[slp] = -t[b] @ A[sl] sl = slp # Perform svd @@ -1880,11 +1972,11 @@ def _scattering_state_svd_a(self, elec, cutoff=0, **kwargs): E=data.E, k=data.k, cutoff_elec=cutoff0, - cutoff=cutoff1 + cutoff=cutoff1, ) return si.physics.StateCElectron(A.T, DOS, self, **info) - def scattering_matrix(self, elec_from, elec_to, E, k=(0, 0, 0), cutoff=0.): + def scattering_matrix(self, elec_from, elec_to, E, k=(0, 0, 0), cutoff=0.0): r""" Calculate the scattering matrix (S-matrix) between `elec_from` and `elec_to` The scattering matrix is calculated as @@ -1964,8 +2056,7 @@ def calc_S(elec_from, jtgam_from, elec_to, tgam_to, G): return ret tgam_from = 1j * tG[elec_from] - S = tuple(calc_S(elec_from, tgam_from, elec, tG[elec], G) - for elec in elec_to) + S = tuple(calc_S(elec_from, tgam_from, elec, tG[elec], G) for elec in elec_to) if is_single: return S[0] @@ -2055,13 +2146,10 @@ def eigenchannel(self, state, elec_to, ret_coeff=False): tt, Ut = eigh_destroy(Ut) tt *= 2 * np.pi - info = {**state.info, - "elec_to": tuple(self._elec_name(e) for e in elec_to) - } + info = {**state.info, "elec_to": tuple(self._elec_name(e) for e in elec_to)} # Backtransform U to form the eigenchannels - teig = si.physics.StateCElectron(Ut[:, ::-1].T @ A, - tt[::-1], self, **info) + teig = si.physics.StateCElectron(Ut[:, ::-1].T @ A, tt[::-1], self, **info) if ret_coeff: return teig, si.physics.StateElectron(Ut[:, ::-1].T, self, **info) return teig diff --git a/src/sisl_toolbox/cli/__init__.py b/src/sisl_toolbox/cli/__init__.py index bc84a60fd8..e34afb296e 100644 --- a/src/sisl_toolbox/cli/__init__.py +++ b/src/sisl_toolbox/cli/__init__.py @@ -9,13 +9,13 @@ class SToolBoxCLI: - """ Run the CLI `stoolbox` """ + """Run the CLI `stoolbox`""" def __init__(self): self._cmds = [] def register(self, setup): - """ Register a setup callback function which creates the subparser + """Register a setup callback function which creates the subparser The ``setup(..)`` command must accept a sub-parser from `argparse` as its first argument. @@ -42,8 +42,9 @@ def __call__(self, argv=None): # Create command-line cmd = Path(sys.argv[0]) - p = argparse.ArgumentParser(f"{cmd.name}", - description="Specific toolboxes to aid sisl users") + p = argparse.ArgumentParser( + f"{cmd.name}", description="Specific toolboxes to aid sisl users" + ) info = { "title": "Toolboxes", diff --git a/src/sisl_toolbox/models/_base.py b/src/sisl_toolbox/models/_base.py index 81a170272c..da7c1d6951 100644 --- a/src/sisl_toolbox/models/_base.py +++ b/src/sisl_toolbox/models/_base.py @@ -5,24 +5,32 @@ class ModelDispatcher(ClassDispatcher): - """ Container for dispatch models """ + """Container for dispatch models""" + pass class BaseModel: - """ Base class used for inheritance for creating separate models """ - ref = ModelDispatcher("ref", - type_dispatcher=None, - obj_getattr=lambda obj, key: - (_ for _ in ()).throw( - AttributeError((f"{obj}.to does not implement '{key}' " - f"dispatcher, are you using it incorrectly?")) - ) + """Base class used for inheritance for creating separate models""" + + ref = ModelDispatcher( + "ref", + type_dispatcher=None, + obj_getattr=lambda obj, key: (_ for _ in ()).throw( + AttributeError( + ( + f"{obj}.to does not implement '{key}' " + f"dispatcher, are you using it incorrectly?" + ) + ) + ), ) + # Each model should inherit from this class ReferenceDispatch(AbstractDispatch): - """ Base dispatcher that implemnets different models """ + """Base dispatcher that implemnets different models""" + pass diff --git a/src/sisl_toolbox/models/_graphene/__init__.py b/src/sisl_toolbox/models/_graphene/__init__.py index 8c41692cc7..406e589fae 100644 --- a/src/sisl_toolbox/models/_graphene/__init__.py +++ b/src/sisl_toolbox/models/_graphene/__init__.py @@ -7,7 +7,7 @@ # Here we import the specific details that are exposed from ._hamiltonian import * -__all__ = ['graphene'] +__all__ = ["graphene"] # Define the graphene model graphene = PropertyDict() diff --git a/src/sisl_toolbox/models/_graphene/_base.py b/src/sisl_toolbox/models/_graphene/_base.py index d031fd9165..b8bd2cc60f 100644 --- a/src/sisl_toolbox/models/_graphene/_base.py +++ b/src/sisl_toolbox/models/_graphene/_base.py @@ -7,7 +7,6 @@ class GrapheneModel(BaseModel): - # copy the dispatcher method ref = BaseModel.ref.copy() @@ -16,7 +15,7 @@ class GrapheneModel(BaseModel): # The distances are here kept @classmethod def distance(cls, n=1, a=1.42): - """ Return the distance to the nearest neighbour according to the bond-length `a` + """Return the distance to the nearest neighbour according to the bond-length `a` Currently only up to 3rd nearest neighbour is implemneted @@ -28,9 +27,9 @@ def distance(cls, n=1, a=1.42): the bond length of the intrinsic graphene lattice """ dist = { - 0: 0., + 0: 0.0, 1: a, - 2: a * 3 ** 0.5, + 2: a * 3**0.5, 3: a * 2, } return dist[n] diff --git a/src/sisl_toolbox/models/_graphene/_hamiltonian.py b/src/sisl_toolbox/models/_graphene/_hamiltonian.py index 9939d650df..58f28de008 100644 --- a/src/sisl_toolbox/models/_graphene/_hamiltonian.py +++ b/src/sisl_toolbox/models/_graphene/_hamiltonian.py @@ -17,60 +17,64 @@ class GrapheneHamiltonian(GrapheneModel): class SimpleDispatch(ReferenceDispatch): - """ This implements the simple nearest neighbour TB model """ + """This implements the simple nearest neighbour TB model""" def dispatch(self, t=-2.7, a=1.42, orthogonal=False): # Define the graphene lattice da = 0.0005 - C = si.Atom(6, si.AtomicOrbital(n=2, l=1, m=0, R=a+da)) + C = si.Atom(6, si.AtomicOrbital(n=2, l=1, m=0, R=a + da)) graphene = si.geom.graphene(a, C, orthogonal=orthogonal) # Define the Hamiltonian H = si.Hamiltonian(graphene) - H.construct([(da, a+da), (0, t)]) + H.construct([(da, a + da), (0, t)]) return H + GrapheneHamiltonian.ref.register("simple", SimpleDispatch) class Hancock2010Dispatch(ReferenceDispatch): - """ Implementing reference models from 10.1103/PhysRevB.81.245402 """ + """Implementing reference models from 10.1103/PhysRevB.81.245402""" + doi = "10.1103/PhysRevB.81.245402" - def dispatch(self, set='A', a=1.42, orthogonal=False): + def dispatch(self, set="A", a=1.42, orthogonal=False): distance = self._obj.distance da = 0.0005 H_orthogonal = True - #U = 2.0 + # U = 2.0 - R = tuple(distance(i, a)+da for i in range(4)) - if set == 'A': + R = tuple(distance(i, a) + da for i in range(4)) + if set == "A": # same as simple t = (0, -2.7) - #U = 0. - elif set == 'B': + # U = 0. + elif set == "B": # same as simple t = (0, -2.7) - elif set == 'C': + elif set == "C": t = (0, -2.7, -0.2) - elif set == 'D': + elif set == "D": t = (0, -2.7, -0.2, -0.18) - elif set == 'E': + elif set == "E": # same as D, but specific for GNR t = (0, -2.7, -0.2, -0.18) - elif set == 'F': + elif set == "F": # same as D, but specific for GNR t = [(0, 1), (-2.7, 0.11), (-0.09, 0.045), (-0.27, 0.065)] H_orthogonal = False - elif set == 'G': + elif set == "G": # same as D, but specific for GNR t = [(0, 1), (-2.97, 0.073), (-0.073, 0.018), (-0.33, 0.026)] - #U = 0. + # U = 0. H_orthogonal = False else: - raise ValueError(f"Set specification for {self.doi} does not exist, should be one of [A-G]") + raise ValueError( + f"Set specification for {self.doi} does not exist, should be one of [A-G]" + ) # Reduce size of R - R = R[:len(t)] + R = R[: len(t)] # Currently we do not carry over U, since it is not specified for the # sisl objects.... @@ -84,12 +88,13 @@ def dispatch(self, set='A', a=1.42, orthogonal=False): H.construct([R, t]) return H + GrapheneHamiltonian.ref.register("Hancock2010", Hancock2010Dispatch) GrapheneHamiltonian.ref.register(Hancock2010Dispatch.doi, Hancock2010Dispatch) class Ishii2010Dispatch(ReferenceDispatch): - r""" Implementing reference model from 10.1103/PhysRevLett.104.116801 + r"""Implementing reference model from 10.1103/PhysRevLett.104.116801 Instead of using the :math:`\lambda_0` as parameter name, we use ``t`` for the coupling strength. @@ -100,11 +105,13 @@ def dispatch(self, t=-2.7, a=1.42, orthogonal=False): distance = self._obj.distance da = 0.0005 - R = (distance(0, a)+da, distance(1, a)+da) + R = (distance(0, a) + da, distance(1, a) + da) + def construct(H, ia, atoms, atoms_xyz=None): - idx_t01, rij_t01 = H.geometry.close(ia, R=R, - atoms=atoms, atoms_xyz=atoms_xyz, ret_rij=True) - H[ia, idx_t01[0]] = 0. + idx_t01, rij_t01 = H.geometry.close( + ia, R=R, atoms=atoms, atoms_xyz=atoms_xyz, ret_rij=True + ) + H[ia, idx_t01[0]] = 0.0 H[ia, idx_t01[1]] = t * (a / rij_t01[1]) ** 2 # Define the graphene lattice @@ -115,23 +122,29 @@ def construct(H, ia, atoms, atoms_xyz=None): H.construct(construct) return H + GrapheneHamiltonian.ref.register("Ishii2010", Ishii2010Dispatch) GrapheneHamiltonian.ref.register(Ishii2010Dispatch.doi, Ishii2010Dispatch) class Cummings2019Dispatch(ReferenceDispatch): - """ Implementing reference model from 10.1021/acs.nanolett.9b03112 """ + """Implementing reference model from 10.1021/acs.nanolett.9b03112""" + doi = "10.1021/acs.nanolett.9b03112" - def dispatch(self, t=(-2.414, -0.168), beta=(-1.847, -3.077), a=1.42, orthogonal=False): + def dispatch( + self, t=(-2.414, -0.168), beta=(-1.847, -3.077), a=1.42, orthogonal=False + ): distance = self._obj.distance da = 0.0005 - R = (distance(0, a)+da, distance(1, a)+da, distance(2, a)+da) + R = (distance(0, a) + da, distance(1, a) + da, distance(2, a) + da) + def construct(H, ia, atoms, atoms_xyz=None): - idx_t012, rij_t012 = H.geometry.close(ia, R=R, - atoms=atoms, atoms_xyz=atoms_xyz, ret_rij=True) - H[ia, idx_t012[0]] = 0. + idx_t012, rij_t012 = H.geometry.close( + ia, R=R, atoms=atoms, atoms_xyz=atoms_xyz, ret_rij=True + ) + H[ia, idx_t012[0]] = 0.0 H[ia, idx_t012[1]] = t[0] * np.exp(beta[0] * (rij_t012[1] - R[1])) H[ia, idx_t012[2]] = t[1] * np.exp(beta[1] * (rij_t012[2] - R[2])) @@ -143,19 +156,26 @@ def construct(H, ia, atoms, atoms_xyz=None): H.construct(construct) return H + GrapheneHamiltonian.ref.register("Cummings2019", Cummings2019Dispatch) GrapheneHamiltonian.ref.register(Cummings2019Dispatch.doi, Cummings2019Dispatch) class Wu2011Dispatch(ReferenceDispatch): - """ Implementing reference model from 10.1007/s11671-010-9791-y """ + """Implementing reference model from 10.1007/s11671-010-9791-y""" + doi = "10.1007/s11671-010-9791-y" def dispatch(self, a=1.42, orthogonal=False): distance = self._obj.distance da = 0.0005 - R = (distance(0, a)+da, distance(1, a)+da, distance(2, a)+da, distance(3, a)+da) + R = ( + distance(0, a) + da, + distance(1, a) + da, + distance(2, a) + da, + distance(3, a) + da, + ) # Define the graphene lattice C = si.Atom(6, si.AtomicOrbital(n=2, l=1, m=0, R=R[-1])) graphene = si.geom.graphene(a, C, orthogonal=orthogonal) @@ -165,5 +185,6 @@ def dispatch(self, a=1.42, orthogonal=False): H.construct([R, t]) return H + GrapheneHamiltonian.ref.register("Wu2011", Wu2011Dispatch) GrapheneHamiltonian.ref.register(Wu2011Dispatch.doi, Wu2011Dispatch) diff --git a/src/sisl_toolbox/siesta/atom/_atom.py b/src/sisl_toolbox/siesta/atom/_atom.py index 8f740d5f45..530d89eddd 100644 --- a/src/sisl_toolbox/siesta/atom/_atom.py +++ b/src/sisl_toolbox/siesta/atom/_atom.py @@ -65,15 +65,27 @@ # 19 | 7p | 6 | 118 | [Og] _shell_order = [ # Occupation shell order - '1s', # [He] - '2s', '2p', # [Ne] - '3s', '3p', # [Ar] - '3d', '4s', '4p', # [Kr] - '4d', '5s', '5p', # [Xe] - '4f', '5d', '6s', '6p', # [Rn] - '7s', '5f', '6d', '7p' # [Og] + "1s", # [He] + "2s", + "2p", # [Ne] + "3s", + "3p", # [Ar] + "3d", + "4s", + "4p", # [Kr] + "4d", + "5s", + "5p", # [Xe] + "4f", + "5d", + "6s", + "6p", # [Rn] + "7s", + "5f", + "6d", + "7p", # [Og] ] -_spdfgh = 'spdfgh' +_spdfgh = "spdfgh" class AtomInput: @@ -103,9 +115,9 @@ class AtomInput: .. [AtomLicense] https://siesta.icmab.es/SIESTA_MATERIAL/Pseudos/atom_licence.html """ - def __init__(self, atom, - define=('NEW_CC', 'FREE_FORMAT_RC_INPUT', 'NO_PS_CUTOFFS'), - **opts): + def __init__( + self, atom, define=("NEW_CC", "FREE_FORMAT_RC_INPUT", "NO_PS_CUTOFFS"), **opts + ): # opts = { # "flavor": "tm2", # "xc": "pb", @@ -127,39 +139,49 @@ def __init__(self, atom, l = 0 for orb in self.atom: if orb.l != l: - raise ValueError(f"{self.__class__.__name__} atom argument does not have " - f"increasing l quantum number index {l} has l={orb.l}") + raise ValueError( + f"{self.__class__.__name__} atom argument does not have " + f"increasing l quantum number index {l} has l={orb.l}" + ) l += 1 if l != 4: - raise ValueError(f"{self.__class__.__name__} atom argument must have 4 orbitals. " - f"One for each s-p-d-f shell") + raise ValueError( + f"{self.__class__.__name__} atom argument must have 4 orbitals. " + f"One for each s-p-d-f shell" + ) self.opts = PropertyDict(**opts) # Check options passed and define defaults self.opts.setdefault("equation", "r") - if self.opts.equation not in ' rs': + if self.opts.equation not in " rs": # ' ' == non-polarized # s == polarized # r == relativistic - raise ValueError(f"{self.__class__.__name__} failed to initialize; opts{'equation': } has wrong value, should be [ rs].") - if self.opts.equation == 's': - raise NotImplementedError(f"{self.__class__.__name__} does not implement spin-polarized option (use relativistic)") + raise ValueError( + f"{self.__class__.__name__} failed to initialize; opts{'equation': } has wrong value, should be [ rs]." + ) + if self.opts.equation == "s": + raise NotImplementedError( + f"{self.__class__.__name__} does not implement spin-polarized option (use relativistic)" + ) self.opts.setdefault("flavor", "tm2") - if self.opts.flavor not in ('hsc', 'ker', 'tm2'): + if self.opts.flavor not in ("hsc", "ker", "tm2"): # hsc == Hamann-Schluter-Chiang # ker == Kerker # tm2 == Troullier-Martins - raise ValueError(f"{self.__class__.__name__} failed to initialize; opts{'flavor': } has wrong value, should be [hsc|ker|tm2].") + raise ValueError( + f"{self.__class__.__name__} failed to initialize; opts{'flavor': } has wrong value, should be [hsc|ker|tm2]." + ) - self.opts.setdefault("logr", 2.) + self.opts.setdefault("logr", 2.0) # default to true if set self.opts.setdefault("cc", "rcore" in self.opts) # rcore only used if cc is True - self.opts.setdefault("rcore", 0.) + self.opts.setdefault("rcore", 0.0) self.opts.setdefault("xc", "pb") # Read in the core valence shells for this atom @@ -171,14 +193,19 @@ def __init__(self, atom, # _shell_order.index("2p") == 2 # which has 1s and 2s occupied. try: - core = reduce(min, (_shell_order.index(f"{orb.n}{_spdfgh[orb.l]}") - for orb in atom), len(_shell_order)) + core = reduce( + min, + (_shell_order.index(f"{orb.n}{_spdfgh[orb.l]}") for orb in atom), + len(_shell_order), + ) except Exception: core = -1 self.opts.setdefault("core", core) if self.opts.core == -1: - raise ValueError(f"Default value for {self.atom.symbol} not added, please add core= at instantiation") + raise ValueError( + f"Default value for {self.atom.symbol} not added, please add core= at instantiation" + ) # Store the defined names if define is None: @@ -190,16 +217,17 @@ def __init__(self, atom, @classmethod def from_input(cls, inp): - """ Return atom object respecting the input + """Return atom object respecting the input Parameters ---------- inp : list or str create `AtomInput` from the content of `inp` """ + def _get_content(f): if f.is_file(): - return open(f, 'r').readlines() + return open(f, "r").readlines() return None if isinstance(inp, (tuple, list)): @@ -214,7 +242,9 @@ def _get_content(f): if content is None: content = _get_content(inp / "INP") if content is None: - raise ValueError(f"Could not find any input file in {str(inp)} or {str(inp / 'INP')}") + raise ValueError( + f"Could not find any input file in {str(inp)} or {str(inp / 'INP')}" + ) inp = content else: @@ -300,8 +330,9 @@ def bypass(inp, defines): @classmethod def from_yaml(cls, file, nodes=()): - """ Parse the yaml file """ + """Parse the yaml file""" from sisl_toolbox.siesta.minimizer._yaml_reader import parse_variable, read_yaml + dic = read_yaml(file, nodes) element = dic["element"] @@ -316,7 +347,9 @@ def from_yaml(cls, file, nodes=()): opts["xc"] = pseudo.get("xc") opts["equation"] = pseudo.get("equation") opts["flavor"] = pseudo.get("flavor") - define = pseudo.get("define", ('NEW_CC', 'FREE_FORMAT_RC_INPUT', 'NO_PS_CUTOFFS')) + define = pseudo.get( + "define", ("NEW_CC", "FREE_FORMAT_RC_INPUT", "NO_PS_CUTOFFS") + ) # Now on to parsing the valence shells orbs = [] @@ -327,7 +360,7 @@ def from_yaml(cls, file, nodes=()): # Now we know the occupation is a shell pseudo = dic[key].get("pseudo", {}) cutoff = parse_variable(pseudo.get("cutoff"), 2.1, "Ang").value - charge = parse_variable(pseudo.get("charge"), 0.).value + charge = parse_variable(pseudo.get("charge"), 0.0).value orbs.append(si.AtomicOrbital(key, m=0, R=cutoff, q0=charge)) atom = si.Atom(element, orbs, mass=mass, tag=tag) @@ -348,7 +381,7 @@ def write_generation(self, f): pg = "pg" logr = self.opts.logr * _Ang2Bohr f.write(f" {pg:2s} {self.atom.symbol} pseudo potential\n") - if logr < 0.: + if logr < 0.0: f.write(f" {self.opts.flavor:3s}\n") else: f.write(f" {self.opts.flavor:3s}{logr:9.3f}\n") @@ -367,15 +400,17 @@ def write_generation(self, f): f.write(f"{core:5d}{valence:5d}\n") - Rs = [0.] * 4 # always 4: s, p, d, f + Rs = [0.0] * 4 # always 4: s, p, d, f for orb in sorted(atom.orbitals, key=lambda x: x.l): # Write the configuration of this orbital n, l = orb.n, orb.l f.write(f"{n:5d}{l:5d}{orb.q0:10.3f}{0.0:10.3f}\n") Rs[l] = orb.R * _Ang2Bohr - f.write(f"{Rs[0]:10.7f} {Rs[1]:10.7f} {Rs[2]:10.7f} {Rs[3]:10.7f} {0.0:10.7f} {rcore:10.7f}\n") + f.write( + f"{Rs[0]:10.7f} {Rs[1]:10.7f} {Rs[2]:10.7f} {Rs[3]:10.7f} {0.0:10.7f} {rcore:10.7f}\n" + ) - def write_all_electron(self, f, charges=(0.,)): + def write_all_electron(self, f, charges=(0.0,)): q0 = self.atom.q0.sum() xc = self.opts.xc equation = self.opts.equation @@ -431,9 +466,9 @@ def _get_out(self, path, filename): return Path(path) / Path(filename) def __call__(self, filename="INP", path=None): - """ Open a file and return self """ + """Open a file and return self""" out = self._get_out(path, filename) - self._enter = open(out, 'w'), self.atom + self._enter = open(out, "w"), self.atom return self def __enter__(self): @@ -457,29 +492,38 @@ def pg(self, filename="INP", path=None): self.write_generation(f) def excite(self, *charge, **lq): - """ Excite contained atom to another charge + """Excite contained atom to another charge Notes ----- This is charge, *not* electrons. """ if len(charge) > 1: - raise ValueError(f"{self.__class__.__name__}.excite takes only " - "a single argument or [spdf]=charge arguments") + raise ValueError( + f"{self.__class__.__name__}.excite takes only " + "a single argument or [spdf]=charge arguments" + ) elif len(charge) == 1: charge = charge[0] if "charge" in lq: - raise ValueError(f"{self.__class__.__name__}.excite does not accept " - "both charge as argument and keyword argument") + raise ValueError( + f"{self.__class__.__name__}.excite does not accept " + "both charge as argument and keyword argument" + ) else: charge = 0 charge = lq.pop("charge", charge) # get indices of the orders - shell_idx = [_shell_order.index(f"{orb.n}{_spdfgh[orb.l]}") - for orb in self.atom] + shell_idx = [ + _shell_order.index(f"{orb.n}{_spdfgh[orb.l]}") for orb in self.atom + ] + def _charge(idx): # while the g and h shells are never used, we just have them... - return {'s': 2, 'p': 6, 'd': 10, 'f': 14, 'g': 18, 'h': 22}[_shell_order[idx][1]] + return {"s": 2, "p": 6, "d": 10, "f": 14, "g": 18, "h": 22}[ + _shell_order[idx][1] + ] + orig_charge = [_charge(idx) for idx in shell_idx] atom = self.atom.copy() # find order of orbitals (from highest index to lowest) @@ -492,7 +536,7 @@ def _charge(idx): for orb in atom: if orb.l == l: orb._q0 -= q - assert orb.q0 >= 0. + assert orb.q0 >= 0.0 # now finalize the charge while abs(charge) > 1e-9: @@ -504,15 +548,19 @@ def _charge(idx): elif charge < 0 and orb.q0 < orig_charge[idx]: dq = max(orb.q0 - orig_charge[idx], charge) else: - dq = 0. + dq = 0.0 orb._q0 -= dq charge -= dq return self.__class__(atom, self.define, **self.opts) - def plot(self, path=None, - plot=('wavefunction', 'charge', 'log', 'potential'), - l='spdf', show=True): - """ Plot everything related to this psf file + def plot( + self, + path=None, + plot=("wavefunction", "charge", "log", "potential"), + l="spdf", + show=True, + ): + """Plot everything related to this psf file Parameters ---------- @@ -531,13 +579,14 @@ def plot(self, path=None, axs : axes used for plotting """ import matplotlib.pyplot as plt + if path is None: path = Path.cwd() else: path = Path(path) def get_xy(f, yfactors=None): - """ Return x, y data from file `f` with y being calculated as the factors between the columns """ + """Return x, y data from file `f` with y being calculated as the factors between the columns""" nonlocal path f = path / f if not f.is_file(): @@ -549,17 +598,22 @@ def get_xy(f, yfactors=None): if yfactors is None: yfactors = [0, 1] - yfactors = np.pad(yfactors, (0, ncol-len(yfactors)), constant_values=0.) + yfactors = np.pad(yfactors, (0, ncol - len(yfactors)), constant_values=0.0) x = data[:, 0] y = (data * yfactors.reshape(1, -1)).sum(1) return x, y l2i = { - 's': 0, 0: 0, - 'p': 1, 1: 1, - 'd': 2, 2: 2, - 'f': 3, 3: 3, - 'g': 4, 4: 4, # never used + "s": 0, + 0: 0, + "p": 1, + 1: 1, + "d": 2, + 2: 2, + "f": 3, + 3: 3, + "g": 4, + 4: 4, # never used } # Get this atoms default calculated binding length @@ -585,10 +639,10 @@ def plot_wavefunction(ax): r, w = get_xy(f"PSWFNR{il}") if not r is None: - ax.plot(r, w, '--', label=f"PS {_spdfgh[il]}") + ax.plot(r, w, "--", label=f"PS {_spdfgh[il]}") ax.set_xlim(0, atom_r * 5) - ax.autoscale(enable=True, axis='y', tight=True) + ax.autoscale(enable=True, axis="y", tight=True) ax.legend() def plot_charge(ax): @@ -605,20 +659,20 @@ def plot_charge(ax): color = p[0].get_color() if self.opts.get("cc", False): ax.axvline(self.opts.rcore * _Ang2Bohr, color=color, alpha=0.5) - ax.plot(ae_r, ae_vc, '--', label=f"AE valence") + ax.plot(ae_r, ae_vc, "--", label=f"AE valence") ps_r, ps_cc = get_xy("PSCHARGE", [0, 0, 0, 1]) _, ps_vc = get_xy("PSCHARGE", [0, 1, 1]) if not ps_r is None: - ax.plot(ps_r, ps_cc, '--', label=f"PS core") - ax.plot(ps_r, ps_vc, ':', label=f"PS valence") + ax.plot(ps_r, ps_cc, "--", label=f"PS core") + ax.plot(ps_r, ps_vc, ":", label=f"PS valence") # Now determine the overlap between all-electron core-charge # and the pseudopotential valence charge if np.allclose(ae_r, ps_r): # Determine dR - #dr = ae_r[1] - ae_r[0] + # dr = ae_r[1] - ae_r[0] # Integrate number of core-electrons and valence electrons core_c = np.trapz(ae_cc, ae_r) @@ -633,16 +687,18 @@ def plot_charge(ax): # the core one for r < r_pc. # Tests show that it might be located where the core charge density is from 1 to 2 times # larger than the valence charge density - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): fraction = ae_cc / ps_vc np.nan_to_num(fraction, copy=False) - ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis - ax2.plot(ae_r, fraction, 'k', alpha=0.5, label='c/v') + ax2 = ( + ax.twinx() + ) # instantiate a second axes that shares the same x-axis + ax2.plot(ae_r, fraction, "k", alpha=0.5, label="c/v") - marks = np.array([0.5, 1., 1.5]) + marks = np.array([0.5, 1.0, 1.5]) min_x = (ae_r > 0.2).nonzero()[0].min() - max_x = (fraction[min_x:] > 0.).nonzero()[0].max() + min_x + max_x = (fraction[min_x:] > 0.0).nonzero()[0].max() + min_x r_marks = interp1d(fraction[min_x:max_x], ae_r[min_x:max_x])(marks) ax2.scatter(r_marks, marks, alpha=0.5) @@ -669,8 +725,9 @@ def plot_log(ax): if not e is None: p = ax.plot(e, log, label=f"AE {_spdfgh[il]}") - idx_mark = (np.fabs(e.reshape(-1, 1) - emark.reshape(1, -1)) - .argmin(axis=0)) + idx_mark = np.fabs(e.reshape(-1, 1) - emark.reshape(1, -1)).argmin( + axis=0 + ) ax.scatter(emark, log[idx_mark], color=p[0].get_color(), alpha=0.5) # And now PS @@ -680,10 +737,11 @@ def plot_log(ax): emark.shape = (1, -1) emark = emark[:, 0] if not e is None: - p = ax.plot(e, log, ':', label=f"PS {_spdfgh[il]}") + p = ax.plot(e, log, ":", label=f"PS {_spdfgh[il]}") - idx_mark = (np.fabs(e.reshape(-1, 1) - emark.reshape(1, -1)) - .argmin(axis=0)) + idx_mark = np.fabs(e.reshape(-1, 1) - emark.reshape(1, -1)).argmin( + axis=0 + ) ax.scatter(emark, log[idx_mark], color=p[0].get_color(), alpha=0.5) ax.legend() @@ -738,7 +796,7 @@ def next_rc(ir, ic, nrows, ncols): def atom_plot_cli(subp=None): - """ Run plotting command for the output of atom """ + """Run plotting command for the output of atom""" is_sub = not subp is None @@ -749,23 +807,32 @@ def atom_plot_cli(subp=None): p = subp.add_parser("atom-plot", description=title, help=title) else: import argparse + p = argparse.ArgumentParser(title) - p.add_argument("--plot", '-P', action='append', type=str, - choices=('wavefunction', 'charge', 'log', 'potential'), - help="""Determine what to plot""") + p.add_argument( + "--plot", + "-P", + action="append", + type=str, + choices=("wavefunction", "charge", "log", "potential"), + help="""Determine what to plot""", + ) - p.add_argument("-l", default='spdf', type=str, - help="""Which l shells to plot""") + p.add_argument("-l", default="spdf", type=str, help="""Which l shells to plot""") - p.add_argument("--save", "-S", default=None, - help="""Save output plots to file.""") + p.add_argument("--save", "-S", default=None, help="""Save output plots to file.""") - p.add_argument("--show", default=False, action='store_true', - help="""Force showing the plot (only if --save is specified)""") + p.add_argument( + "--show", + default=False, + action="store_true", + help="""Force showing the plot (only if --save is specified)""", + ) - p.add_argument("input", type=str, default="INP", - help="""Input file name (default INP)""") + p.add_argument( + "input", type=str, default="INP", help="""Input file name (default INP)""" + ) if is_sub: p.set_defaults(runner=atom_plot) @@ -788,7 +855,7 @@ def atom_plot(args): # if users have not specified what to plot, we plot everything if args.plot is None: - args.plot = ('wavefunction', 'charge', 'log', 'potential') + args.plot = ("wavefunction", "charge", "log", "potential") fig = atom.plot(path, plot=args.plot, l=args.l, show=False)[0] if args.save is None: diff --git a/src/sisl_toolbox/siesta/minimizer/_atom_basis.py b/src/sisl_toolbox/siesta/minimizer/_atom_basis.py index 8edbd763c1..e9a20a3ce4 100644 --- a/src/sisl_toolbox/siesta/minimizer/_atom_basis.py +++ b/src/sisl_toolbox/siesta/minimizer/_atom_basis.py @@ -19,7 +19,7 @@ class AtomBasis: - """Basis block format for Siesta """ + """Basis block format for Siesta""" def __init__(self, atom, opts=None): # opts = {(n, l): # or n=1, l=0 1s @@ -43,7 +43,9 @@ def __init__(self, atom, opts=None): else: self.opts = opts if not isinstance(self.opts, dict): - raise ValueError(f"{self.__class__.__name__} must get `opts` as a dictionary argument") + raise ValueError( + f"{self.__class__.__name__} must get `opts` as a dictionary argument" + ) # Assert that we have options corresonding to the orbitals present for key in self.opts.keys(): @@ -56,7 +58,9 @@ def __init__(self, atom, opts=None): if orb.n == n and orb.l == l: found = True if not found: - raise ValueError("Options passed for n={n} l={l}, but no orbital with that signiture is present?") + raise ValueError( + "Options passed for n={n} l={l}, but no orbital with that signiture is present?" + ) # ensure each orbital has an option associated for (n, l), orbs in self.yield_nl_orbs(): @@ -64,13 +68,14 @@ def __init__(self, atom, opts=None): @classmethod def from_dict(cls, dic): - """ Return an `AtomBasis` from a dictionary + """Return an `AtomBasis` from a dictionary Parameters ---------- dic : dict """ from sisl_toolbox.siesta.atom._atom import _shell_order + element = dic["element"] tag = dic.get("tag") mass = dic.get("mass", None) @@ -105,16 +110,16 @@ def get_radius(orbs, zeta): if key in ("charge-confinement", "charge-conf"): opt_nl["charge"] = [ parse_variable(entry.get("charge")).value, - parse_variable(entry.get("yukawa"), unit='1/Ang').value, - parse_variable(entry.get("width"), unit='Ang').value + parse_variable(entry.get("yukawa"), unit="1/Ang").value, + parse_variable(entry.get("width"), unit="Ang").value, ] elif key in ("soft-confinement", "soft-conf"): opt_nl["soft"] = [ - parse_variable(entry.get("V0"), unit='eV').value, - parse_variable(entry.get("ri"), unit='Ang').value + parse_variable(entry.get("V0"), unit="eV").value, + parse_variable(entry.get("ri"), unit="Ang").value, ] elif key in ("filter",): - opt_nl["filter"] = parse_variable(entry, unit='eV').value + opt_nl["filter"] = parse_variable(entry, unit="eV").value elif key in ("split-norm", "split"): opt_nl["split"] = parse_variable(entry).value elif key in ("polarization", "pol"): @@ -122,9 +127,9 @@ def get_radius(orbs, zeta): elif key.startswith("zeta"): # cutoff of zeta zeta = int(key[4:]) - R = parse_variable(entry, unit='Ang').value + R = parse_variable(entry, unit="Ang").value if R < 0: - R *= -get_radius(orbs_nl, zeta-1) + R *= -get_radius(orbs_nl, zeta - 1) orbs_nl.append(si.AtomicOrbital(n=n, l=l, m=0, zeta=zeta, R=R)) if len(orbs_nl) > 0: @@ -136,13 +141,14 @@ def get_radius(orbs, zeta): @classmethod def from_yaml(cls, file, nodes=()): - """ Parse the yaml file """ + """Parse the yaml file""" from ._yaml_reader import read_yaml + return cls.from_dict(read_yaml(file, nodes)) @classmethod def from_block(cls, block): - """ Return an `Atom` for a specified basis block + """Return an `Atom` for a specified basis block Parameters ---------- @@ -160,7 +166,7 @@ def blockline(): nonlocal block out = "" while len(out) == 0: - out = block.pop(0).split('#')[0].strip() + out = block.pop(0).split("#")[0].strip() return out # define global opts @@ -196,7 +202,7 @@ def blockline(): # This is because the first n= should never # contain a ".", whereas the contraction *should*. if len(block) > 0: - if '.' in block[0].split()[0]: + if "." in block[0].split()[0]: contract_line = blockline() # remove n= @@ -261,7 +267,9 @@ def blockline(): # calculate the radius pass else: - raise ValueError(f"Could not parse the PAO.Basis block for the zeta ranges {rc_line}.") + raise ValueError( + f"Could not parse the PAO.Basis block for the zeta ranges {rc_line}." + ) orb = si.AtomicOrbital(n=n, l=l, m=0, zeta=izeta, R=rc) nzeta -= 1 orbs.append(orb) @@ -271,7 +279,7 @@ def blockline(): # useful to leave the rc's definitions out. rc = orbs[-1].R for izeta in range(nzeta): - orb = si.AtomicOrbital(n=n, l=l, m=0, zeta=orbs[-1].zeta+1, R=rc) + orb = si.AtomicOrbital(n=n, l=l, m=0, zeta=orbs[-1].zeta + 1, R=rc) orbs.append(orb) opts[(n, l)] = nlopts @@ -280,7 +288,7 @@ def blockline(): return cls(atom, opts) def yield_nl_orbs(self): - """ An iterator with each different ``n, l`` pair returned with a list of zeta-shells """ + """An iterator with each different ``n, l`` pair returned with a list of zeta-shells""" orbs = {} for orb in self.atom: # build a dictionary @@ -291,7 +299,7 @@ def yield_nl_orbs(self): yield from orbs.items() def basis(self): - """ Get basis block lines (as list)""" + """Get basis block lines (as list)""" block = [] @@ -360,17 +368,16 @@ def basis(self): # to be sure) orbs_sorted = sorted(orbs, key=lambda orb: orb.zeta) - line = " ".join(map(lambda orb: f"{orb.R*_Ang2Bohr:.10f}", - orbs_sorted)) + line = " ".join(map(lambda orb: f"{orb.R*_Ang2Bohr:.10f}", orbs_sorted)) block.append(line) # We don't need the 1's, they are contraction factors # and we simply keep them the default values - #line = " ".join(map(lambda orb: "1.0000", orbs)) - #block.append(line) + # line = " ".join(map(lambda orb: "1.0000", orbs)) + # block.append(line) return block def get_variables(self, dict_or_yaml, nodes=()): - """ Convert a dictionary or yaml file input to variables usable by the minimizer """ + """Convert a dictionary or yaml file input to variables usable by the minimizer""" if not isinstance(dict_or_yaml, dict): dict_or_yaml = read_yaml(dict_or_yaml) if isinstance(nodes, str): @@ -380,17 +387,17 @@ def get_variables(self, dict_or_yaml, nodes=()): return self._get_variables_dict(dict_or_yaml) def _get_variables_dict(self, dic): - """ Parse a dictionary adding potential variables to the minimize model """ + """Parse a dictionary adding potential variables to the minimize model""" tag = self.atom.tag # with respect to the basis def update_orb(old, new, orb): - """ Update an orbital's radius """ + """Update an orbital's radius""" orb._R = new # Define other options def update(old, new, d, key, index=None): - """ An updater for a dictionary with optional keys """ + """An updater for a dictionary with optional keys""" if index is None: d[key] = new else: @@ -398,6 +405,7 @@ def update(old, new, d, key, index=None): # returned variables V = [] + def add_variable(var): nonlocal V if var.value is not None: @@ -409,8 +417,13 @@ def add_variable(var): # get default options for pseudo basis = dic.get("basis", {}) - add_variable(parse_variable(basis.get("ion-charge"), name=f"{tag}.ion-q", - update_func=partial(update, d=self.opts, key="ion_charge"))) + add_variable( + parse_variable( + basis.get("ion-charge"), + name=f"{tag}.ion-q", + update_func=partial(update, d=self.opts, key="ion_charge"), + ) + ) # parse depending on shells in the atom spdf = "spdfgh" @@ -425,30 +438,79 @@ def add_variable(var): for flag in ("charge-confinement", "charge-conf"): # Now parse this one d = basis.get(flag, {}) - for var in [parse_variable(d.get("charge"), name=f"{tag}.{nl}.charge.q", - update_func=partial(update, d=self.opts[(n, l)], key="charge", index=0)), - parse_variable(d.get("yukawa"), unit='1/Ang', name=f"{tag}.{nl}.charge.yukawa", - update_func=partial(update, d=self.opts[(n, l)], key="charge", index=1)), - parse_variable(d.get("width"), unit='Ang', name=f"{tag}.{nl}.charge.width", - update_func=partial(update, d=self.opts[(n, l)], key="charge", index=2))]: + for var in [ + parse_variable( + d.get("charge"), + name=f"{tag}.{nl}.charge.q", + update_func=partial( + update, d=self.opts[(n, l)], key="charge", index=0 + ), + ), + parse_variable( + d.get("yukawa"), + unit="1/Ang", + name=f"{tag}.{nl}.charge.yukawa", + update_func=partial( + update, d=self.opts[(n, l)], key="charge", index=1 + ), + ), + parse_variable( + d.get("width"), + unit="Ang", + name=f"{tag}.{nl}.charge.width", + update_func=partial( + update, d=self.opts[(n, l)], key="charge", index=2 + ), + ), + ]: add_variable(var) for flag in ("soft-confinement", "soft-conf"): # Now parse this one d = basis.get(flag, {}) - for var in [parse_variable(d.get("V0"), unit='eV', name=f"{tag}.{nl}.soft.V0", - update_func=partial(update, d=self.opts[(n, l)], key="soft", index=0)), - parse_variable(d.get("ri"), unit='Ang', name=f"{tag}.{nl}.soft.ri", - update_func=partial(update, d=self.opts[(n, l)], key="soft", index=1))]: + for var in [ + parse_variable( + d.get("V0"), + unit="eV", + name=f"{tag}.{nl}.soft.V0", + update_func=partial( + update, d=self.opts[(n, l)], key="soft", index=0 + ), + ), + parse_variable( + d.get("ri"), + unit="Ang", + name=f"{tag}.{nl}.soft.ri", + update_func=partial( + update, d=self.opts[(n, l)], key="soft", index=1 + ), + ), + ]: add_variable(var) - add_variable(parse_variable(basis.get("filter"), unit='eV', name=f"{tag}.{nl}.filter", - update_func=partial(update, d=self.opts[(n, l)], key="filter"))) + add_variable( + parse_variable( + basis.get("filter"), + unit="eV", + name=f"{tag}.{nl}.filter", + update_func=partial(update, d=self.opts[(n, l)], key="filter"), + ) + ) for flag in ("split-norm", "split"): - add_variable(parse_variable(basis.get(flag), name=f"{tag}.{nl}.split", - update_func=partial(update, d=self.opts[(n, l)], key="split"))) - - add_variable(parse_variable(basis.get(f"zeta{orb.zeta}"), name=f"{tag}.{nl}.z{orb.zeta}", - update_func=partial(update_orb, orb=orb))) + add_variable( + parse_variable( + basis.get(flag), + name=f"{tag}.{nl}.split", + update_func=partial(update, d=self.opts[(n, l)], key="split"), + ) + ) + + add_variable( + parse_variable( + basis.get(f"zeta{orb.zeta}"), + name=f"{tag}.{nl}.z{orb.zeta}", + update_func=partial(update_orb, orb=orb), + ) + ) return V diff --git a/src/sisl_toolbox/siesta/minimizer/_atom_pseudo.py b/src/sisl_toolbox/siesta/minimizer/_atom_pseudo.py index b79de1fd05..5a8b6deffb 100644 --- a/src/sisl_toolbox/siesta/minimizer/_atom_pseudo.py +++ b/src/sisl_toolbox/siesta/minimizer/_atom_pseudo.py @@ -16,9 +16,8 @@ class AtomPseudo(AtomInput): - def get_variables(self, dict_or_yaml, nodes=()): - """ Convert a dictionary or yaml file input to variables usable by the minimizer """ + """Convert a dictionary or yaml file input to variables usable by the minimizer""" if not isinstance(dict_or_yaml, dict): dict_or_yaml = read_yaml(dict_or_yaml) if isinstance(nodes, str): @@ -28,17 +27,17 @@ def get_variables(self, dict_or_yaml, nodes=()): return self._get_variables_dict(dict_or_yaml) def _get_variables_dict(self, dic): - """ Parse a dictionary adding potential variables to the minimize model """ + """Parse a dictionary adding potential variables to the minimize model""" tag = self.atom.tag # with respect to the basis def update_orb(old, new, orb, key): - """ Update an orbital's radii """ + """Update an orbital's radii""" setattr(orb, f"_{key}", new) # Define other options def update(old, new, d, key, idx=None): - """ An updater for a dictionary with optional keys """ + """An updater for a dictionary with optional keys""" if idx is None: d[key] = new else: @@ -46,6 +45,7 @@ def update(old, new, d, key, idx=None): # returned variables V = [] + def add_variable(var): nonlocal V if var.value is not None: @@ -55,11 +55,24 @@ def add_variable(var): # get default options for pseudo pseudo = dic.get("pseudo", {}) - add_variable(parse_variable(pseudo.get("log-radii"), unit="Ang", name=f"{tag}.logr", - update_func=partial(update, d=self.opts, key="logr"))) - - add_variable(parse_variable(pseudo.get("core-correction"), 0., unit="Ang", name=f"{tag}.core", - update_func=partial(update, d=self.opts, key="rcore"))) + add_variable( + parse_variable( + pseudo.get("log-radii"), + unit="Ang", + name=f"{tag}.logr", + update_func=partial(update, d=self.opts, key="logr"), + ) + ) + + add_variable( + parse_variable( + pseudo.get("core-correction"), + 0.0, + unit="Ang", + name=f"{tag}.core", + update_func=partial(update, d=self.opts, key="rcore"), + ) + ) # parse depending on shells in the atom for orb in self.atom: @@ -70,10 +83,23 @@ def add_variable(var): continue # Now parse this one - add_variable(parse_variable(pseudo.get("cutoff"), orb.R, unit="Ang", name=f"{tag}.{nl}.r", - update_func=partial(update_orb, orb=orb, key="R"))) - - add_variable(parse_variable(pseudo.get("charge"), orb.q0, name=f"{tag}.{nl}.q", - update_func=partial(update_orb, orb=orb, key="q0"))) + add_variable( + parse_variable( + pseudo.get("cutoff"), + orb.R, + unit="Ang", + name=f"{tag}.{nl}.r", + update_func=partial(update_orb, orb=orb, key="R"), + ) + ) + + add_variable( + parse_variable( + pseudo.get("charge"), + orb.q0, + name=f"{tag}.{nl}.q", + update_func=partial(update_orb, orb=orb, key="q0"), + ) + ) return V diff --git a/src/sisl_toolbox/siesta/minimizer/_metric.py b/src/sisl_toolbox/siesta/minimizer/_metric.py index 85261d7a4e..6f6a25bc84 100644 --- a/src/sisl_toolbox/siesta/minimizer/_metric.py +++ b/src/sisl_toolbox/siesta/minimizer/_metric.py @@ -7,10 +7,9 @@ class Metric: - @abstractmethod def metric(self, variables, *args, **kwargs): - """ Return a single number quantifying the metric of the system """ + """Return a single number quantifying the metric of the system""" def __abs__(self): return AbsMetric(self) @@ -50,7 +49,7 @@ def max(self, other): class CompositeMetric(Metric): - """ Placeholder for two metrics """ + """Placeholder for two metrics""" def __init__(self, A, B): self.A = A diff --git a/src/sisl_toolbox/siesta/minimizer/_metric_siesta.py b/src/sisl_toolbox/siesta/minimizer/_metric_siesta.py index c9c52c4fc7..76a5645a76 100644 --- a/src/sisl_toolbox/siesta/minimizer/_metric_siesta.py +++ b/src/sisl_toolbox/siesta/minimizer/_metric_siesta.py @@ -13,7 +13,13 @@ from ._metric import Metric from ._path import path_rel_or_abs -__all__ = ["SiestaMetric", "EnergyMetric", "EigenvalueMetric", "ForceMetric", "StressMetric"] +__all__ = [ + "SiestaMetric", + "EnergyMetric", + "EigenvalueMetric", + "ForceMetric", + "StressMetric", +] _log = logging.getLogger("sisl_toolbox.siesta.minimize") @@ -39,44 +45,50 @@ def _siesta_out_accept(out): class SiestaMetric(Metric): - """ Generic Siesta metric + """Generic Siesta metric Since in some cases siesta may crash we need to have *failure* metrics that returns if siesta fails to run. """ - def __init__(self, failure=0.): + def __init__(self, failure=0.0): if isinstance(failure, Number): + def func(metric, fail): if fail: return failure return metric + self.failure = func elif callable(failure): self.failure = failure else: - raise ValueError(f"{self.__class__.__name__} could not initialize failure, not number or callable") + raise ValueError( + f"{self.__class__.__name__} could not initialize failure, not number or callable" + ) class EigenvalueMetric(SiestaMetric): - """ Compare eigenvalues between two calculations and return the difference as the metric """ + """Compare eigenvalues between two calculations and return the difference as the metric""" - def __init__(self, eig_file, eig_ref, dist=None, align_valence=False, failure=0.): - """ Store the reference eigenvalues along the distribution (if any) """ + def __init__(self, eig_file, eig_ref, dist=None, align_valence=False, failure=0.0): + """Store the reference eigenvalues along the distribution (if any)""" super().__init__(failure) self.eig_file = path_rel_or_abs(eig_file) # we copy to ensure users don't change these self.eig_ref = eig_ref.copy() if dist is None: - self.dist = 1. + self.dist = 1.0 elif callable(dist): self.dist = dist(eig_ref) else: try: eig_ref * dist except Exception: - raise ValueError(f"{self.__class__.__name__} was passed `dist` which was not " - "broadcastable to `eig_ref`. Please ensure compatibility.") + raise ValueError( + f"{self.__class__.__name__} was passed `dist` which was not " + "broadcastable to `eig_ref`. Please ensure compatibility." + ) self.dist = dist.copy() # whether we should align the valence band edges @@ -84,27 +96,29 @@ def __init__(self, eig_file, eig_ref, dist=None, align_valence=False, failure=0. self.align_valence = align_valence def metric(self, variables): - """ Compare eigenvalues with a reference eigenvalue set, scaled by dist """ + """Compare eigenvalues with a reference eigenvalue set, scaled by dist""" try: eig = io_siesta.eigSileSiesta(self.eig_file).read_data() - eig = eig[:, :, :self.eig_ref.shape[2]] + eig = eig[:, :, : self.eig_ref.shape[2]] if self.align_valence: # align data at the valence band ( - eig -= eig[eig < 0.].max() + eig -= eig[eig < 0.0].max() # Calculate the metric, also average around k-points - metric = (((eig - self.eig_ref) * self.dist) ** 2).sum() ** 0.5 / eig.shape[1] + metric = (((eig - self.eig_ref) * self.dist) ** 2).sum() ** 0.5 / eig.shape[ + 1 + ] metric = self.failure(metric, False) _log.debug(f"metric.eigenvalue [{self.eig_file}] success {metric}") except Exception: - metric = self.failure(0., True) + metric = self.failure(0.0, True) _log.warning(f"metric.eigenvalue [{self.eig_file}] fail {metric}") return metric class EnergyMetric(SiestaMetric): - """ Metric is the energy (default total), read from the output file + """Metric is the energy (default total), read from the output file Alternatively the metric could be any operation of the energies that is returned. @@ -121,35 +135,40 @@ class EnergyMetric(SiestaMetric): in case the output does not contain anything runner fails, then we should return a "fake" metric. """ - def __init__(self, out, energy='total', failure=0.): + def __init__(self, out, energy="total", failure=0.0): super().__init__(failure) self.out = path_rel_or_abs(out) if isinstance(energy, str): energy_str = energy.split(".") + def energy(energy_dict): - f""" {'.'.join(energy_str)} metric """ + """energy metric""" for sub in energy_str[:-1]: energy_dict = energy_dict[sub] return energy_dict[energy_str[-1]] + # TODO fix documentation to f""" {'.'.join(energy_str)} metric """ + if not callable(energy): - raise ValueError(f"{self.__class__.__name__} requires energy to be callable or str") + raise ValueError( + f"{self.__class__.__name__} requires energy to be callable or str" + ) self.energy = energy def metric(self, variables): - """ Read the energy from the out file in `path` """ + """Read the energy from the out file in `path`""" out = io_siesta.outSileSiesta(self.out) if _siesta_out_accept(out): metric = self.failure(self.energy(out.read_energy()), False) _log.debug(f"metric.energy [{self.out}] success {metric}") else: - metric = self.failure(0., True) + metric = self.failure(0.0, True) _log.warning(f"metric.energy [{self.out}] fail {metric}") return metric class ForceMetric(SiestaMetric): - """ Metric is the force (default maximum), read from the FA file + """Metric is the force (default maximum), read from the FA file Alternatively the metric could be any operation on the forces. @@ -166,38 +185,43 @@ class ForceMetric(SiestaMetric): in case the output does not contain anything runner fails, then we should return a "fake" metric. """ - def __init__(self, file, force='abs.max', failure=0.): + def __init__(self, file, force="abs.max", failure=0.0): super().__init__(failure) self.file = path_rel_or_abs(file) if isinstance(force, str): force_op = force.split(".") + def force(forces): - f""" {force_op} metric """ + """force metric""" out = forces for op in force_op: if op == "l2": - out = (out ** 2).sum(-1) ** 0.5 + out = (out**2).sum(-1) ** 0.5 else: out = getattr(np, op)(out) return out + + # TODO fix documentation f""" {force_op} metric """ if not callable(force): - raise ValueError(f"{self.__class__.__name__} requires force to be callable or str") + raise ValueError( + f"{self.__class__.__name__} requires force to be callable or str" + ) self.force = force def metric(self, variables): - """ Read the force from the `self.file` in `path` """ + """Read the force from the `self.file` in `path`""" try: force = self.force(get_sile(self.file).read_force()) metric = self.failure(force, False) _log.debug(f"metric.force [{self.file}] success {metric}") except Exception: - metric = self.failure(0., True) + metric = self.failure(0.0, True) _log.debug(f"metric.force [{self.file}] fail {metric}") return metric class StressMetric(SiestaMetric): - """ Metric is the stress tensor, read from the output file + """Metric is the stress tensor, read from the output file Parameters ---------- @@ -210,26 +234,31 @@ class StressMetric(SiestaMetric): in case the output does not contain anything runner fails, then we should return a "fake" metric. """ - def __init__(self, out, stress='ABC', failure=2.): + def __init__(self, out, stress="ABC", failure=2.0): super().__init__(failure) self.out = path_rel_or_abs(out) if isinstance(stress, str): stress_directions = list(map(direction, stress)) + def stress(stress_matrix): - f""" {stress_directions} metric """ + """stress metric""" return stress_matrix[stress_directions, stress_directions].sum() + + # TODO fix documentation f""" {stress_directions} metric """ if not callable(stress): - raise ValueError(f"{self.__class__.__name__} requires stress to be callable") + raise ValueError( + f"{self.__class__.__name__} requires stress to be callable" + ) self.stress = stress def metric(self, variables): - """ Convert the stress-tensor to a single metric that should be minimized """ + """Convert the stress-tensor to a single metric that should be minimized""" out = io_siesta.outSileSiesta(self.out) if _siesta_out_accept(out): stress = self.stress(out.read_stress()) metric = self.failure(stress, False) _log.debug(f"metric.stress [{self.out}] success {metric}") else: - metric = self.failure(0., True) + metric = self.failure(0.0, True) _log.warning(f"metric.stress [{self.out}] fail {metric}") return metric diff --git a/src/sisl_toolbox/siesta/minimizer/_minimize.py b/src/sisl_toolbox/siesta/minimizer/_minimize.py index 8e8aea1559..e1157586ce 100644 --- a/src/sisl_toolbox/siesta/minimizer/_minimize.py +++ b/src/sisl_toolbox/siesta/minimizer/_minimize.py @@ -15,14 +15,18 @@ from sisl.io import tableSile from sisl.utils import PropertyDict -__all__ = ["BaseMinimize", "LocalMinimize", "DualAnnealingMinimize", - "MinimizeToDispatcher"] +__all__ = [ + "BaseMinimize", + "LocalMinimize", + "DualAnnealingMinimize", + "MinimizeToDispatcher", +] _log = logging.getLogger("sisl_toolbox.siesta.minimize") def _convert_optimize_result(minimizer, result): - """ Convert optimize result to conform to the scaling procedure performed """ + """Convert optimize result to conform to the scaling procedure performed""" # reverse optimized value # and also store the normalized values (to match the gradients etc) if minimizer.norm[0] in ("none", "identity"): @@ -36,17 +40,17 @@ def _convert_optimize_result(minimizer, result): # The jacobian is dM / dx with dx possibly being scaled # So here we change multiply by dx / dv result.jac_norm = result.jac.copy() - result.jac /= minimizer.reverse_normalize(np.ones(len(minimizer)), - with_offset=False) + result.jac /= minimizer.reverse_normalize( + np.ones(len(minimizer)), with_offset=False + ) return result class BaseMinimize: - # Basic minimizer basically used for figuring out whether # to use a local or global minimization strategy - def __init__(self, variables=(), out="minimize.dat", norm='identity'): + def __init__(self, variables=(), out="minimize.dat", norm="identity"): # ensure we have an ordered dict, for one reason or the other self.variables = [] if variables is not None: @@ -55,7 +59,7 @@ def __init__(self, variables=(), out="minimize.dat", norm='identity'): self.reset(out, norm) def reset(self, out=None, norm=None): - """ Reset data table to be able to restart """ + """Reset data table to be able to restart""" # While this *could* be a named-tuple, we would not be able # to override the attribute, hence we use a property dict # same effect. @@ -71,7 +75,7 @@ def reset(self, out=None, norm=None): if not norm is None: log += f" norm={str(norm)}" if isinstance(norm, str): - self.norm = (norm, 1.) + self.norm = (norm, 1.0) elif isinstance(norm, Real): self.norm = ("l2", norm) else: @@ -84,11 +88,15 @@ def normalize(self, variables, with_offset=True): # of each variable out = np.empty(len(self.variables)) for i, v in enumerate(self.variables): - out[i] = v.normalize(v.attrs[variables], self.norm, with_offset=with_offset) + out[i] = v.normalize( + v.attrs[variables], self.norm, with_offset=with_offset + ) else: out = np.empty_like(variables) for i, v in enumerate(variables): - out[i] = self.variables[i].normalize(v, self.norm, with_offset=with_offset) + out[i] = self.variables[i].normalize( + v, self.norm, with_offset=with_offset + ) return out def normalize_bounds(self): @@ -98,7 +106,9 @@ def reverse_normalize(self, variables, with_offset=True): # ensures numpy array out = np.empty_like(variables) for i, v in enumerate(variables): - out[i] = self.variables[i].reverse_normalize(v, self.norm, with_offset=with_offset) + out[i] = self.variables[i].reverse_normalize( + v, self.norm, with_offset=with_offset + ) return out def __getitem__(self, key): @@ -122,26 +132,30 @@ def values(self): return np.array([v.value for v in self.variables], np.float64) def update(self, variables): - """ Update internal variables for the values """ + """Update internal variables for the values""" for var, v in zip(self.variables, variables): var.update(v) def dict_values(self): - """ Get all vaules in a dictionary table """ + """Get all vaules in a dictionary table""" return {v.name: v.value for v in self.variables} # Define a dispatcher for converting Minimize data to some specific data # BaseMinimize().to.skopt() will convert to an skopt.OptimizationResult structure - to = ClassDispatcher("to", - obj_getattr=lambda obj, key: - (_ for _ in ()).throw( - AttributeError((f"{obj}.to does not implement '{key}' " - f"dispatcher, are you using it incorrectly?")) - ) + to = ClassDispatcher( + "to", + obj_getattr=lambda obj, key: (_ for _ in ()).throw( + AttributeError( + ( + f"{obj}.to does not implement '{key}' " + f"dispatcher, are you using it incorrectly?" + ) + ) + ), ) def __enter__(self): - """ Open the file and fill with stuff """ + """Open the file and fill with stuff""" _log.debug(f"__enter__ {self.__class__.__name__}") # check if the file exists @@ -156,7 +170,9 @@ def __enter__(self): if self.out.is_file() and data.size > 0: nvars = data.shape[0] - 1 if nvars != len(self): - raise ValueError(f"Found old file {self.out} which contains previous data for another number of parameters, please delete or move file") + raise ValueError( + f"Found old file {self.out} which contains previous data for another number of parameters, please delete or move file" + ) # now parse header *header, _ = header[1:].split() @@ -172,7 +188,9 @@ def __enter__(self): print(header) print(self.names) print(idx) - raise ValueError(f"Found old file {self.out} which contains previous data with some variables being renamed, please correct header or move file") + raise ValueError( + f"Found old file {self.out} which contains previous data with some variables being renamed, please correct header or move file" + ) # add functional value, no pivot idx.append(len(self)) @@ -193,19 +211,19 @@ def __enter__(self): comment = f"Created by sisl '{self.__class__.__name__}'." header = self.names + ["metric"] if len(self.data.x) == 0: - self._fh = tableSile(self.out, 'w').__enter__() + self._fh = tableSile(self.out, "w").__enter__() self._fh.write_data(comment=comment, header=header) else: comment += f" The first {len(self.data)} lines contains prior content." data = np.column_stack((self.data.x, self.data.y)) - self._fh = tableSile(self.out, 'w').__enter__() - self._fh.write_data(data.T, comment=comment, header=header, fmt='20.17e') + self._fh = tableSile(self.out, "w").__enter__() + self._fh.write_data(data.T, comment=comment, header=header, fmt="20.17e") self._fh.flush() return self def __exit__(self, *args, **kwargs): - """ Exit routine """ + """Exit routine""" self._fh.__exit__(*args, **kwargs) # clean-up del self._fh @@ -215,7 +233,7 @@ def __len__(self): @abstractmethod def __call__(self, variables, *args): - """ Actual running code that takes `variables` conforming to the order of initial setup. + """Actual running code that takes `variables` conforming to the order of initial setup. It will return the functional of the minimize method @@ -226,7 +244,7 @@ def __call__(self, variables, *args): """ def _minimize_func(self, norm_variables, *args): - """ Minimization function passed to the minimization method + """Minimization function passed to the minimization method This is a wrapper which does 3 things: @@ -256,7 +274,9 @@ def _minimize_func(self, norm_variables, *args): try: idx = self.data.hash.index(current_hash) # immediately return functional value that is hashed - _log.info(f"{self.__class__.__name__}._minimize_func, using prior hashed calculation {idx}") + _log.info( + f"{self.__class__.__name__}._minimize_func, using prior hashed calculation {idx}" + ) return self.data.y[idx] except ValueError: @@ -266,7 +286,9 @@ def _minimize_func(self, norm_variables, *args): # Else we have to call minimize metric = np.array(self(variables, *args)) # add the data to the output file and hash it - self._fh.write_data(variables.reshape(-1, 1), metric.reshape(-1, 1), fmt='20.17e') + self._fh.write_data( + variables.reshape(-1, 1), metric.reshape(-1, 1), fmt="20.17e" + ) self._fh.flush() self.data.x.append(variables) self.data.y.append(metric) @@ -276,38 +298,37 @@ def _minimize_func(self, norm_variables, *args): @abstractmethod def run(self, *args, **kwargs): - """ Run the minimize model """ + """Run the minimize model""" class LocalMinimize(BaseMinimize): - def run(self, *args, **kwargs): # Run minimization (always with normalized values) norm_v0 = self.normalize(self.values) bounds = self.normalize_bounds() with self: - opt = minimize(self._minimize_func, - x0=norm_v0, args=args, bounds=bounds, - **kwargs) + opt = minimize( + self._minimize_func, x0=norm_v0, args=args, bounds=bounds, **kwargs + ) return _convert_optimize_result(self, opt) class DualAnnealingMinimize(BaseMinimize): - def run(self, *args, **kwargs): # Run minimization (always with normalized values) norm_v0 = self.normalize(self.values) bounds = self.normalize_bounds() with self: - opt = dual_annealing(self._minimize_func, - x0=norm_v0, args=args, bounds=bounds, - **kwargs) + opt = dual_annealing( + self._minimize_func, x0=norm_v0, args=args, bounds=bounds, **kwargs + ) return _convert_optimize_result(self, opt) class MinimizeToDispatcher(AbstractDispatch): - """ Base dispatcher from class passing from Minimize class """ + """Base dispatcher from class passing from Minimize class""" + @staticmethod def _ensure_object(obj): if isinstance(obj, type): @@ -317,15 +338,19 @@ def _ensure_object(obj): class MinimizeToskoptDispatcher(MinimizeToDispatcher): def dispatch(self, *args, **kwargs): import skopt + minim = self._obj self._ensure_object(minim) if len(args) > 0: - raise ValueError(f"{minim.__class__.__name__}.to.skopt only accepts keyword arguments") + raise ValueError( + f"{minim.__class__.__name__}.to.skopt only accepts keyword arguments" + ) # First create the Space variable def skoptReal(v): low, high = v.bounds return skopt.space.Real(low, high, transform="identity", name=v.name) + space = skopt.Space(list(map(skoptReal, self.variables))) # Extract sampled data-points @@ -341,19 +366,22 @@ def skoptReal(v): # We can't use categorial (SVC) since these are regression models # fast, but should not be as accurate? - #model = sklearn.svm.LinearSVR() + # model = sklearn.svm.LinearSVR() # much slower, but more versatile # I don't know which one is better ;) model = sklearn.svm.SVR(cache_size=500) - #model = sklearn.svm.NuSVR(kernel="poly", cache_size=500) + # model = sklearn.svm.NuSVR(kernel="poly", cache_size=500) # we need to fit to create auxiliary data - warnings.warn(f"Converting to skopt without a 'models' argument forces " - f"{minim.__class__.__name__} to train a model for the sampled data. " - f"This may be slow depending on the number of samples...") + warnings.warn( + f"Converting to skopt without a 'models' argument forces " + f"{minim.__class__.__name__} to train a model for the sampled data. " + f"This may be slow depending on the number of samples..." + ) model.fit(Xi, yi) kwargs["models"] = [model] result = skopt.utils.create_result(Xi, yi, space=space, **kwargs) return result + BaseMinimize.to.register("skopt", MinimizeToskoptDispatcher) diff --git a/src/sisl_toolbox/siesta/minimizer/_minimize_siesta.py b/src/sisl_toolbox/siesta/minimizer/_minimize_siesta.py index 3b4a8a5fdd..6a12817898 100644 --- a/src/sisl_toolbox/siesta/minimizer/_minimize_siesta.py +++ b/src/sisl_toolbox/siesta/minimizer/_minimize_siesta.py @@ -21,8 +21,8 @@ _log = logging.getLogger("sisl_toolbox.siesta.minimize") -class MinimizeSiesta(BaseMinimize): # no inheritance! - """ A minimize minimizer for siesta (PP and basis or what-ever) +class MinimizeSiesta(BaseMinimize): # no inheritance! + """A minimize minimizer for siesta (PP and basis or what-ever) It is important that a this gets initialized with ``runner`` and ``metric`` keyword arguments. @@ -34,7 +34,8 @@ def __init__(self, runner, metric, *args, **kwargs): self.metric = metric def get_constraints(self, factor=0.95): - """ Return contraints for the zeta channels """ + """Return contraints for the zeta channels""" + # Now we define the constraints of the orbitals. def unpack(name): try: @@ -47,16 +48,12 @@ def unpack(name): except Exception: return None, None, None, None - orb_R = {} # (n,l) = {1: idx-nlzeta=1, 2: idx-nlzeta=2} + orb_R = {} # (n,l) = {1: idx-nlzeta=1, 2: idx-nlzeta=2} for i, v in enumerate(self.variables): symbol, n, l, zeta = unpack(v.name) if symbol is None: continue - (orb_R - .setdefault(symbol, {}) - .setdefault((n, l), {}) - .update({zeta: i}) - ) + (orb_R.setdefault(symbol, {}).setdefault((n, l), {}).update({zeta: i})) def assert_bounds(i1, i2): v1 = self.variables[i1] @@ -64,7 +61,9 @@ def assert_bounds(i1, i2): b1 = v1.bounds b2 = v2.bounds if not np.allclose(b1, b2): - raise ValueError("Bounds for zeta must be the same due to normalization") + raise ValueError( + "Bounds for zeta must be the same due to normalization" + ) # get two lists of neighbouring zeta's # Our constraint is that zeta cutoffs should be descending. @@ -74,7 +73,7 @@ def assert_bounds(i1, i2): for z_idx in atom.values(): for i in range(2, max(z_idx.keys()) + 1): # this will request zeta-indices in order (zeta1, zeta2, ...) - zeta1.append(z_idx[i-1]) + zeta1.append(z_idx[i - 1]) zeta2.append(z_idx[i]) assert_bounds(zeta1[-1], zeta2[-1]) @@ -87,6 +86,7 @@ def fun(v): # an inequality constraint must return a non-negative # zeta1.R * `factor` - zeta2.R >= 0. return v[zeta1] * factor - v[zeta2] + return fun def jac_factory(factor, zeta1, zeta2): @@ -98,19 +98,22 @@ def jac(v): out[idx, zeta1] = factor out[idx, zeta2] = -1 return out + return jac constr = [] - constr.append({ - "type": "ineq", - "fun": fun_factory(factor, zeta1, zeta2), - "jac": jac_factory(factor, zeta1, zeta2), - }) + constr.append( + { + "type": "ineq", + "fun": fun_factory(factor, zeta1, zeta2), + "jac": jac_factory(factor, zeta1, zeta2), + } + ) return constr def candidates(self, delta=1e-2, target=None, sort="max"): - """ Compare samples and find candidates within a delta-metric of `delta` + """Compare samples and find candidates within a delta-metric of `delta` Candidiates are ordered around the basis-set sizes. This means that *zeta* variables are the only ones used for figuring out @@ -141,8 +144,7 @@ def candidates(self, delta=1e-2, target=None, sort="max"): ytarget = y[idx_target] # Now find all valid samples - valid = np.logical_and(ytarget - delta <= y, - y <= ytarget + delta).nonzero()[0] + valid = np.logical_and(ytarget - delta <= y, y <= ytarget + delta).nonzero()[0] # Reduce to candidate points x_valid = x[valid] @@ -166,7 +168,9 @@ def candidates(self, delta=1e-2, target=None, sort="max"): # no need for sqrt (does nothing for sort) idx_increasing = np.argsort((x_valid[:, idx_R] ** 2).sum(axis=1)) else: - raise ValueError(f"{self.__class__.__name__}.candidates got an unknown value for 'sort={sort}', must be one of [max,l1,l2].") + raise ValueError( + f"{self.__class__.__name__}.candidates got an unknown value for 'sort={sort}', must be one of [max,l1,l2]." + ) else: # it really has to be callable ;) idx_increasing = sort(x_valid, y_valid) diff --git a/src/sisl_toolbox/siesta/minimizer/_runner.py b/src/sisl_toolbox/siesta/minimizer/_runner.py index dc5c601a59..bd5a286d94 100644 --- a/src/sisl_toolbox/siesta/minimizer/_runner.py +++ b/src/sisl_toolbox/siesta/minimizer/_runner.py @@ -12,9 +12,17 @@ from ._path import path_abs, path_rel_or_abs -__all__ = ["AbstractRunner", "AndRunner", "PathRunner", - "CleanRunner", "CopyRunner", "CommandRunner", - "AtomRunner", "SiestaRunner", "FunctionRunner"] +__all__ = [ + "AbstractRunner", + "AndRunner", + "PathRunner", + "CleanRunner", + "CopyRunner", + "CommandRunner", + "AtomRunner", + "SiestaRunner", + "FunctionRunner", +] _log = logging.getLogger("sisl_toolbox.siesta.minimize") @@ -26,7 +34,7 @@ def commonprefix(*paths): class AbstractRunner(ABC): - """ Define a runner """ + """Define a runner""" def __iter__(self): yield self @@ -36,11 +44,11 @@ def __and__(self, other): @abstractmethod def run(self, *args, **kwargs): - """ Run this runner """ + """Run this runner""" class AndRunner(AbstractRunner): - """ Placeholder for two runners """ + """Placeholder for two runners""" def __init__(self, A, B): self.A = A @@ -52,7 +60,7 @@ def __iter__(self): yield from self.B def run(self, A=None, B=None, **kwargs): - """ Run `self.A` first, then `self.B` + """Run `self.A` first, then `self.B` Both runners get ``kwargs`` as arguments, and `A` only gets passed to `self.A`. @@ -68,7 +76,7 @@ def run(self, A=None, B=None, **kwargs): ------- tuple of return from `self.A` and `self.B` """ - #print("running A") + # print("running A") if A is None: A = self.A.run(**kwargs) else: @@ -77,7 +85,7 @@ def run(self, A=None, B=None, **kwargs): A = self.A.run(**kw) if not isinstance(self.A, AndRunner): A = (A,) - #print("running B") + # print("running B") if B is None: B = self.B.run(**kwargs) else: @@ -91,7 +99,7 @@ def run(self, A=None, B=None, **kwargs): class PathRunner(AbstractRunner): - """ Define a runner """ + """Define a runner""" def __init__(self, path): self.path = path_abs(path) @@ -145,9 +153,13 @@ def __init__(self, from_path, to_path, *files, **rename): self.files = files self.rename = rename if not self.path.is_dir(): - raise ValueError(f"{self.__class__.__name__} path={self.path} must be a directory") + raise ValueError( + f"{self.__class__.__name__} path={self.path} must be a directory" + ) if not self.to.is_dir(): - raise ValueError(f"{self.__class__.__name__} path={self.to} must be a directory") + raise ValueError( + f"{self.__class__.__name__} path={self.to} must be a directory" + ) def run(self): copy = [] @@ -179,13 +191,17 @@ def run(self): class CommandRunner(PathRunner): - def __init__(self, path, cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, hook=None): + def __init__( + self, path, cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, hook=None + ): super().__init__(path) abs_cmd = path_abs(cmd, self.path) if abs_cmd.is_file(): self.cmd = [abs_cmd] if not os.access(self.cmd, os.X_OK): - raise ValueError(f"{self.__class__.__name__} shell script {self.cmd.relative_to(self.path.cwd())} not executable") + raise ValueError( + f"{self.__class__.__name__} shell script {self.cmd.relative_to(self.path.cwd())} not executable" + ) else: self.cmd = cmd.split() @@ -199,18 +215,20 @@ def __init__(self, path, cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ho self.stderr = path_rel_or_abs(stderr, self.path) if hook is None: + def hook(subprocess_output): return subprocess_output + assert callable(hook) self.hook = hook def _get_standard(self): out = self.stdout if isinstance(out, (Path, str)): - out = open(out, 'w') + out = open(out, "w") err = self.stderr if isinstance(err, (Path, str)): - err = open(err, 'w') + err = open(err, "w") return out, err def run(self): @@ -220,12 +238,20 @@ def run(self): # We need to clean the directory so that subsequent VPSFMT users don't # accidentially use a prior output stdout, stderr = self._get_standard() - return self.hook(subprocess.run(cmd, cwd=self.path, encoding='utf-8', - stdout=stdout, stderr=stderr, check=False)) + return self.hook( + subprocess.run( + cmd, + cwd=self.path, + encoding="utf-8", + stdout=stdout, + stderr=stderr, + check=False, + ) + ) class AtomRunner(CommandRunner): - """ Run a command with atom-input file as first argument and output file as second argument + """Run a command with atom-input file as first argument and output file as second argument This is tailored for atom in the sense of arguments for this class, but not restricted in any way. @@ -242,7 +268,15 @@ class AtomRunner(CommandRunner): ... # cd .. """ - def __init__(self, path, cmd="atom", input="INP", stdout=subprocess.PIPE, stderr=subprocess.PIPE, hook=None): + def __init__( + self, + path, + cmd="atom", + input="INP", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + hook=None, + ): super().__init__(path, cmd, stdout, stderr, hook) self.input = path_rel_or_abs(input, self.path) @@ -254,12 +288,20 @@ def run(self): # accidentially use a prior output self.clean("RHO", "OUT", "PS*", "AE*", "CHARGE", "COREQ", "FOURIER*", "VPS*") stdout, stderr = self._get_standard() - return self.hook(subprocess.run(cmd, cwd=self.path, encoding='utf-8', - stdout=stdout, stderr=stderr, check=False)) + return self.hook( + subprocess.run( + cmd, + cwd=self.path, + encoding="utf-8", + stdout=stdout, + stderr=stderr, + check=False, + ) + ) class SiestaRunner(CommandRunner): - """ Run a script/cmd with fdf as first argument and output file as second argument + """Run a script/cmd with fdf as first argument and output file as second argument This is tailored for Siesta in the sense of arguments for this class, but not restricted in any way. @@ -274,17 +316,27 @@ class SiestaRunner(CommandRunner): ... # cd .. """ - def __init__(self, path, cmd="siesta", fdf="RUN.fdf", stdout="RUN.out", stderr=subprocess.PIPE, hook=None): + def __init__( + self, + path, + cmd="siesta", + fdf="RUN.fdf", + stdout="RUN.out", + stderr=subprocess.PIPE, + hook=None, + ): super().__init__(path, cmd, stdout, stderr, hook) self.fdf = path_rel_or_abs(fdf, self.path) fdf = self.absattr("fdf") - self.systemlabel = fdfSileSiesta(fdf, base=self.path).get("SystemLabel", "siesta") + self.systemlabel = fdfSileSiesta(fdf, base=self.path).get( + "SystemLabel", "siesta" + ) def run(self): pipe = "" stdout, stderr = self._get_standard() - for pre, f in [('>', stdout), ('2>', stderr)]: + for pre, f in [(">", stdout), ("2>", stderr)]: try: pipe += f"{pre} {f.name}" except Exception: @@ -293,12 +345,20 @@ def run(self): _log.debug(f"running Siesta using command[{self.path}]: {' '.join(cmd)} {pipe}") # Remove stuff to ensure that we don't read information from prior calculations self.clean("*.ion*", "fdf-*.log", f"{self.systemlabel}.*") - return self.hook(subprocess.run(cmd, cwd=self.path, encoding='utf-8', - stdout=stdout, stderr=stderr, check=False)) + return self.hook( + subprocess.run( + cmd, + cwd=self.path, + encoding="utf-8", + stdout=stdout, + stderr=stderr, + check=False, + ) + ) class FunctionRunner(AbstractRunner): - """ Run a method `func` with specified arguments and kwargs """ + """Run a method `func` with specified arguments and kwargs""" def __init__(self, func, *args, **kwargs): self.func = func diff --git a/src/sisl_toolbox/siesta/minimizer/_variable.py b/src/sisl_toolbox/siesta/minimizer/_variable.py index 4f060f1a24..03047fa0de 100644 --- a/src/sisl_toolbox/siesta/minimizer/_variable.py +++ b/src/sisl_toolbox/siesta/minimizer/_variable.py @@ -5,11 +5,11 @@ import numpy as np -__all__ = ['Parameter', 'Variable', 'UpdateVariable'] +__all__ = ["Parameter", "Variable", "UpdateVariable"] class Parameter: - """ A parameter which is static and not changing. + """A parameter which is static and not changing. Parameters ---------- @@ -37,11 +37,11 @@ def update(self, value): self.value = value def normalize(self, value, *args, **kwargs): - """ Return normalized value """ + """Return normalized value""" return value def reverse_normalize(self, value, *args, **kwargs): - """ Return normalized value """ + """Return normalized value""" return value def __str__(self): @@ -54,7 +54,7 @@ def __eq__(self, other): class Variable(Parameter): - """ A minimization variable with associated name, inital value, and possible bounds. + """A minimization variable with associated name, inital value, and possible bounds. Parameters ---------- @@ -78,28 +78,28 @@ def __str__(self): return f"{self.__class__.__name__}{{name: {self.name}, value: {self.value}, bounds: {self.bounds}}}" def _parse_norm(self, norm, with_offset): - """ Return offset, scale factor """ + """Return offset, scale factor""" if isinstance(norm, str): - scale = 1. + scale = 1.0 elif isinstance(norm, Iterable): norm, scale = norm else: scale = norm - norm = 'l2' + norm = "l2" if with_offset: off = self.bounds[0] else: - off = 0. + off = 0.0 norm = norm.lower() - if norm in ('none', 'identity'): + if norm in ("none", "identity"): # a norm of none will never scale, nor offset - return 0., 1. - elif norm == 'l2': + return 0.0, 1.0 + elif norm == "l2": return off, scale / (self.bounds[1] - self.bounds[0]) raise ValueError("norm not found in [none/identity, l2]") - def normalize(self, value, norm='l2', with_offset=True): - """ Normalize a value in terms of the norms of this variable + def normalize(self, value, norm="l2", with_offset=True): + """Normalize a value in terms of the norms of this variable Parameters ---------- @@ -113,8 +113,8 @@ def normalize(self, value, norm='l2', with_offset=True): offset, fac = self._parse_norm(norm, with_offset) return (value - offset) * fac - def reverse_normalize(self, value, norm='l2', with_offset=True): - """ Revert what `normalize` does + def reverse_normalize(self, value, norm="l2", with_offset=True): + """Revert what `normalize` does Parameters ---------- @@ -135,7 +135,7 @@ def __init__(self, name, value, bounds, func, **attrs): self._func = func def update(self, value): - """ Also run update wrapper call for the new value + """Also run update wrapper call for the new value The update routine should have this interface: diff --git a/src/sisl_toolbox/siesta/minimizer/_yaml_reader.py b/src/sisl_toolbox/siesta/minimizer/_yaml_reader.py index ffe9acde2c..5fad46668e 100644 --- a/src/sisl_toolbox/siesta/minimizer/_yaml_reader.py +++ b/src/sisl_toolbox/siesta/minimizer/_yaml_reader.py @@ -12,7 +12,7 @@ def read_yaml(file, nodes=()): - """ Reads a yaml-file and returns the dictionary for the yaml-file + """Reads a yaml-file and returns the dictionary for the yaml-file Parameters ---------- @@ -21,7 +21,7 @@ def read_yaml(file, nodes=()): nodes : iterable, optional extract the node in a consecutive manner """ - dic = yaml.load(open(file, 'r'), Loader=Loader) + dic = yaml.load(open(file, "r"), Loader=Loader) if isinstance(nodes, str): nodes = [nodes] for node in nodes: @@ -30,7 +30,7 @@ def read_yaml(file, nodes=()): def parse_value(value, unit=None, value_unit=None): - """ Converts a float/str to proper value """ + """Converts a float/str to proper value""" if isinstance(value, str): value, value_unit = value.split() if len(value_unit) == 0: @@ -44,8 +44,8 @@ def parse_value(value, unit=None, value_unit=None): return value * units(value_unit, unit) -def parse_variable(value, default=None, unit=None, name='', update_func=None): - """ Parse a value to either a `Parameter` or `Variable` with defaults and units +def parse_variable(value, default=None, unit=None, name="", update_func=None): + """Parse a value to either a `Parameter` or `Variable` with defaults and units Parameters ---------- @@ -62,9 +62,10 @@ def parse_variable(value, default=None, unit=None, name='', update_func=None): the update function to be called if required """ from ._variable import Parameter, UpdateVariable, Variable + attrs = {} if unit is not None: - attrs = {'unit': unit} + attrs = {"unit": unit} if isinstance(value, dict): value_unit = value.get("unit", unit) diff --git a/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py b/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py index e03c5792f6..43b261f1f5 100644 --- a/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py +++ b/src/sisl_toolbox/transiesta/poisson/fftpoisson_fix.py @@ -56,7 +56,7 @@ import sisl as si -__all__ = ['pyamg_solve', 'solve_poisson', 'fftpoisson_fix_cli', 'fftpoisson_fix_run'] +__all__ = ["pyamg_solve", "solve_poisson", "fftpoisson_fix_cli", "fftpoisson_fix_run"] _BC = si.BoundaryCondition @@ -71,31 +71,54 @@ def pyamg_solve(A, b, tolerance=1e-12, accel=None, title=""): import pyamg + print(f"\nSetting up pyamg solver... {title}") ml = pyamg.aggregation.smoothed_aggregation_solver(A, max_levels=1000) del A print(ml) residuals = [] + def callback(x): # residuals calculated in the solve function is a pre-conditioned residual - #residuals.append(np.linalg.norm(b - A.dot(x)) ** 0.5) - print(" {:4d} residual = {:.5e} x0-residual = {:.5e}".format(len(residuals) - 1, residuals[-1], residuals[-1] / residuals[0])) - x = ml.solve(b, tol=tolerance, callback=callback, residuals=residuals, - accel=accel, cycle='W', maxiter=1e7) - print('Done solving the Poisson equation!') + # residuals.append(np.linalg.norm(b - A.dot(x)) ** 0.5) + print( + " {:4d} residual = {:.5e} x0-residual = {:.5e}".format( + len(residuals) - 1, residuals[-1], residuals[-1] / residuals[0] + ) + ) + + x = ml.solve( + b, + tol=tolerance, + callback=callback, + residuals=residuals, + accel=accel, + cycle="W", + maxiter=1e7, + ) + print("Done solving the Poisson equation!") return x -def solve_poisson(geometry, shape, radius="empirical", - dtype=np.float64, tolerance=1e-8, - accel=None, boundary_fft=True, - device_val=None, plot_boundary=False, - box=False, boundary=None, **elecs_V): - """ Solve Poisson equation """ +def solve_poisson( + geometry, + shape, + radius="empirical", + dtype=np.float64, + tolerance=1e-8, + accel=None, + boundary_fft=True, + device_val=None, + plot_boundary=False, + box=False, + boundary=None, + **elecs_V, +): + """Solve Poisson equation""" error = False elecs = [] for name in geometry.names: - if ('+' in name) or (name in ["Buffer", "Device"]): + if ("+" in name) or (name in ["Buffer", "Device"]): continue # This is actually an electrode @@ -110,24 +133,37 @@ def solve_poisson(geometry, shape, radius="empirical", for name in elecs: if not name in elecs_V: print(f" missing electrode bias: {name}") - raise ValueError(f"{_script}: Missing electrode arguments for specifying the bias.") + raise ValueError( + f"{_script}: Missing electrode arguments for specifying the bias." + ) if boundary is None: bc = [[_BC.PERIODIC, _BC.PERIODIC] for _ in range(3)] else: bc = [] + def bc2bc(s): - return {'periodic': 'PERIODIC', 'p': 'PERIODIC', _BC.PERIODIC: 'PERIODIC', - 'dirichlet': 'DIRICHLET', 'd': 'DIRICHLET', _BC.DIRICHLET: 'DIRICHLET', - 'neumann': 'NEUMANN', 'n': 'NEUMANN', _BC.NEUMANN: 'NEUMANN', + return { + "periodic": "PERIODIC", + "p": "PERIODIC", + _BC.PERIODIC: "PERIODIC", + "dirichlet": "DIRICHLET", + "d": "DIRICHLET", + _BC.DIRICHLET: "DIRICHLET", + "neumann": "NEUMANN", + "n": "NEUMANN", + _BC.NEUMANN: "NEUMANN", }.get(s.lower(), s.upper()) + for bottom, top in boundary: bc.append([getattr(_BC, bc2bc(bottom)), getattr(_BC, bc2bc(top))]) if len(bc) != 3: - raise ValueError(f"{_script}: Requires a 3x2 list input for the boundary conditions.") + raise ValueError( + f"{_script}: Requires a 3x2 list input for the boundary conditions." + ) def _create_shape_tree(xyz, A, B=None): - """ Takes two lists A and B which returns a shape with a binary nesting + """Takes two lists A and B which returns a shape with a binary nesting This makes further index handling much faster. """ @@ -159,10 +195,12 @@ def _create_shape_tree(xyz, A, B=None): # Create grid geometry.set_boundary_condition(bc) grid = si.Grid(shape, geometry=geometry, dtype=dtype) + class _fake: @property def shape(self): return shape + @property def dtype(self): return dtype @@ -206,8 +244,13 @@ def dtype(self): grid.grid = b.reshape(shape) del A else: - x = pyamg_solve(A, b, tolerance=tolerance, accel=accel, - title="solving electrode boundary conditions") + x = pyamg_solve( + A, + b, + tolerance=tolerance, + accel=accel, + title="solving electrode boundary conditions", + ) grid.grid = x.reshape(shape) del A, b @@ -217,8 +260,10 @@ def dtype(self): # This ensures that once we set the boundaries we don't # get any side-effects BC = si.BoundaryCondition - periodic = [bc == BC.PERIODIC or geometry.nsc[i] > 1 - for i, bc in enumerate(grid.lattice.boundary_condition[:, 0])] + periodic = [ + bc == BC.PERIODIC or geometry.nsc[i] > 1 + for i, bc in enumerate(grid.lattice.boundary_condition[:, 0]) + ] bc = np.repeat(np.array([BC.DIRICHLET], np.int32), 6).reshape(3, 2) for i in (0, 1, 2): if periodic[i]: @@ -240,15 +285,20 @@ def sl2idx(grid, sl): new_sl = sl[:] new_sl[i] = slice(0, 1) idx = sl2idx(grid, new_sl) - grid.pyamg_fix(A, b, idx, grid.grid[new_sl[0], new_sl[1], new_sl[2]].reshape(-1)) + grid.pyamg_fix( + A, b, idx, grid.grid[new_sl[0], new_sl[1], new_sl[2]].reshape(-1) + ) new_sl[i] = slice(grid.shape[i] - 1, grid.shape[i]) idx = sl2idx(grid, new_sl) - grid.pyamg_fix(A, b, idx, grid.grid[new_sl[0], new_sl[1], new_sl[2]].reshape(-1)) + grid.pyamg_fix( + A, b, idx, grid.grid[new_sl[0], new_sl[1], new_sl[2]].reshape(-1) + ) if plot_boundary: dat = b.reshape(*grid.shape) # now plot every plane import matplotlib.pyplot as plt + slicex3 = np.index_exp[:] * 3 axs = [ np.linspace(0, grid.lattice.length[ax], shape, endpoint=False) @@ -275,8 +325,13 @@ def sl2idx(grid, sl): plt.show() grid.grid = _fake() - x = pyamg_solve(A, b, tolerance=tolerance, accel=accel, - title="removing electrode boundaries and solving for edge fixing") + x = pyamg_solve( + A, + b, + tolerance=tolerance, + accel=accel, + title="removing electrode boundaries and solving for edge fixing", + ) grid.grid = x.reshape(shape) del A, b @@ -295,61 +350,157 @@ def fftpoisson_fix_cli(subp=None): else: p = argp.ArgumentParser(title) - tuning = p.add_argument_group("tuning", "Tuning fine details of the Poisson calculation.") - - p.add_argument("--geometry", "-G", default="siesta.TBT.nc", metavar="FILE", - help="siesta.TBT.nc file which contains the geometry and electrode information, currently we cannot read that from fdf-files.") - - p.add_argument("--shape", "-s", nargs=3, type=int, required=True, metavar=("A", "B", "C"), - help="Grid shape, this *has* to be conforming to the TranSiesta calculation, read from output: 'InitMesh: MESH = A x B x C'") + tuning = p.add_argument_group( + "tuning", "Tuning fine details of the Poisson calculation." + ) + + p.add_argument( + "--geometry", + "-G", + default="siesta.TBT.nc", + metavar="FILE", + help="siesta.TBT.nc file which contains the geometry and electrode information, currently we cannot read that from fdf-files.", + ) + + p.add_argument( + "--shape", + "-s", + nargs=3, + type=int, + required=True, + metavar=("A", "B", "C"), + help="Grid shape, this *has* to be conforming to the TranSiesta calculation, read from output: 'InitMesh: MESH = A x B x C'", + ) n = {"a": "first", "b": "second", "c": "third"} for d in "abc": - p.add_argument(f"--boundary-condition-{d}", f"-bc-{d}", nargs=2, type=str, default=["p", "p"], - metavar=("BOTTOM", "TOP"), - help=("Boundary condition along the {} lattice vector [periodic/p, neumann/n, dirichlet/d]. " - "Specify separate BC at the start and end of the lattice vector, respectively.".format(n[d]))) - - p.add_argument("--elec-V", "-V", action="append", nargs=2, metavar=("NAME", "V"), default=[], - help="Specify chemical potential on electrode") - - p.add_argument("--pyamg-shape", "-ps", nargs=3, type=int, metavar=("A", "B", "C"), default=None, - help="Grid used to solve the Poisson equation, if shape is different the Grid will be interpolated (order=2) after.") - - p.add_argument("--device", "-D", type=float, default=None, metavar="VAL", - help="Fix the value of all device atoms to a value. In some cases this turns out to yield a better box boundary. The default is to *not* fix the potential on the device atoms.") - - tuning.add_argument("--radius", "-R", type=float, default=3., metavar="R", - help=("Radius of atoms when figuring out the electrode sizes, this corresponds to the extend of " - "each electrode where boundary conditions are fixed. Should be tuned according to the atomic species [3 Ang]")) - - tuning.add_argument("--dtype", "-d", choices=["d", "f64", "f", "f32"], default="d", - help="Precision of data (d/f64==double, f/f32==single)") - - tuning.add_argument("--tolerance", "-T", type=float, default=1e-7, metavar="EPS", - help="Precision required for the pyamg solver. NOTE when using single precision arrays this should probably be on the order of 1e-5") - - tuning.add_argument("--acceleration", "-A", dest="accel", default="cg", metavar="METHOD", - help="""Acceleration method for pyamg. May be useful if it fails to converge - -Try one of: cg, gmres, fgmres, cr, cgnr, cgne, bicgstab, steepest_descent, minimal_residual""") - - test = p.add_argument_group("testing", "Options used for testing output. None of these options should be used for production runs!") - test.add_argument("--box", dest="box", action="store_true", default=False, - help="Only store the initial box solution (i.e. do not run PyAMG)") - - test.add_argument("--no-boundary-fft", action="store_false", dest="boundary_fft", default=True, - help="Once the electrode boundary conditions are solved we perform a second solution with boundaries fixed. Using this flag disables this second solution.") + p.add_argument( + f"--boundary-condition-{d}", + f"-bc-{d}", + nargs=2, + type=str, + default=["p", "p"], + metavar=("BOTTOM", "TOP"), + help=( + "Boundary condition along the {} lattice vector [periodic/p, neumann/n, dirichlet/d]. " + "Specify separate BC at the start and end of the lattice vector, respectively.".format( + n[d] + ) + ), + ) + + p.add_argument( + "--elec-V", + "-V", + action="append", + nargs=2, + metavar=("NAME", "V"), + default=[], + help="Specify chemical potential on electrode", + ) + + p.add_argument( + "--pyamg-shape", + "-ps", + nargs=3, + type=int, + metavar=("A", "B", "C"), + default=None, + help="Grid used to solve the Poisson equation, if shape is different the Grid will be interpolated (order=2) after.", + ) + + p.add_argument( + "--device", + "-D", + type=float, + default=None, + metavar="VAL", + help="Fix the value of all device atoms to a value. In some cases this turns out to yield a better box boundary. The default is to *not* fix the potential on the device atoms.", + ) + + tuning.add_argument( + "--radius", + "-R", + type=float, + default=3.0, + metavar="R", + help=( + "Radius of atoms when figuring out the electrode sizes, this corresponds to the extend of " + "each electrode where boundary conditions are fixed. Should be tuned according to the atomic species [3 Ang]" + ), + ) + + tuning.add_argument( + "--dtype", + "-d", + choices=["d", "f64", "f", "f32"], + default="d", + help="Precision of data (d/f64==double, f/f32==single)", + ) + + tuning.add_argument( + "--tolerance", + "-T", + type=float, + default=1e-7, + metavar="EPS", + help="Precision required for the pyamg solver. NOTE when using single precision arrays this should probably be on the order of 1e-5", + ) + + tuning.add_argument( + "--acceleration", + "-A", + dest="accel", + default="cg", + metavar="METHOD", + help="""Acceleration method for pyamg. May be useful if it fails to converge + +Try one of: cg, gmres, fgmres, cr, cgnr, cgne, bicgstab, steepest_descent, minimal_residual""", + ) + + test = p.add_argument_group( + "testing", + "Options used for testing output. None of these options should be used for production runs!", + ) + test.add_argument( + "--box", + dest="box", + action="store_true", + default=False, + help="Only store the initial box solution (i.e. do not run PyAMG)", + ) + + test.add_argument( + "--no-boundary-fft", + action="store_false", + dest="boundary_fft", + default=True, + help="Once the electrode boundary conditions are solved we perform a second solution with boundaries fixed. Using this flag disables this second solution.", + ) if _DEBUG: - test.add_argument("--plot", dest="plot", default=None, type=int, - help="Plot grid by averaging over the axis given as argument") - - test.add_argument("--plot-boundary", dest="plot_boundary", action="store_true", - help="Plot all 6 edges of the box with their fixed values (just before 2nd pyamg solve step)") - - p.add_argument("--out", "-o", action="append", default=None, - help="Output file to store the resulting Poisson solution. It *has* to have TSV.nc file ending to make the file conforming with TranSiesta.") + test.add_argument( + "--plot", + dest="plot", + default=None, + type=int, + help="Plot grid by averaging over the axis given as argument", + ) + + test.add_argument( + "--plot-boundary", + dest="plot_boundary", + action="store_true", + help="Plot all 6 edges of the box with their fixed values (just before 2nd pyamg solve step)", + ) + + p.add_argument( + "--out", + "-o", + action="append", + default=None, + help="Output file to store the resulting Poisson solution. It *has* to have TSV.nc file ending to make the file conforming with TranSiesta.", + ) if is_sub: p.set_defaults(runner=fftpoisson_fix_run) @@ -359,7 +510,9 @@ def fftpoisson_fix_cli(subp=None): def fftpoisson_fix_run(args): if args.out is None: - print(f">\n>\n>{_script}: No out-files has been specified, work will be carried out but not saved!\n>\n>\n") + print( + f">\n>\n>{_script}: No out-files has been specified, work will be carried out but not saved!\n>\n>\n" + ) # Read in geometry geometry = si.get_sile(args.geometry).read_geometry() @@ -368,7 +521,9 @@ def fftpoisson_fix_run(args): elecs_V = {} if len(args.elec_V) == 0: print(geometry.names) - raise ValueError(f"{_script}: Please specify all electrode potentials using --elec-V") + raise ValueError( + f"{_script}: Please specify all electrode potentials using --elec-V" + ) for name, V in args.elec_V: elecs_V[name] = float(V) @@ -390,18 +545,29 @@ def fftpoisson_fix_run(args): boundary.append(args.boundary_condition_b) boundary.append(args.boundary_condition_c) - V = solve_poisson(geometry, shape, radius=args.radius, boundary=boundary, - dtype=dtype, tolerance=args.tolerance, box=args.box, - accel=args.accel, boundary_fft=args.boundary_fft, - device_val=args.device, plot_boundary=args.plot_boundary, - **elecs_V) + V = solve_poisson( + geometry, + shape, + radius=args.radius, + boundary=boundary, + dtype=dtype, + tolerance=args.tolerance, + box=args.box, + accel=args.accel, + boundary_fft=args.boundary_fft, + device_val=args.device, + plot_boundary=args.plot_boundary, + **elecs_V, + ) if _DEBUG: if not args.plot is None: dat = V.average(args.plot) import matplotlib.pyplot as plt + axs = [ - np.linspace(0, V.lattice.length[ax], shape, endpoint=False) for ax, shape in enumerate(V.shape) + np.linspace(0, V.lattice.length[ax], shape, endpoint=False) + for ax, shape in enumerate(V.shape) ] idx = list(range(3)) diff --git a/tools/changelog.py b/tools/changelog.py index ff8b65b0bd..26d01c0793 100644 --- a/tools/changelog.py +++ b/tools/changelog.py @@ -84,7 +84,7 @@ def get_authors(revision_range): def get_commit_date(repo, rev): - """ Retrive the object that defines the revision """ + """Retrive the object that defines the revision""" return datetime.datetime.fromtimestamp(repo.commit(rev).committed_date) @@ -121,6 +121,7 @@ def get_pull_requests(repo, revision_range): pass return prs + def read_changelog(prior_rel, current_rel, format="md"): # rst search for item md_item = re.compile(r"^\s*-") @@ -132,8 +133,7 @@ def read_changelog(prior_rel, current_rel, format="md"): # for getting the date date = None out = [] - for line in open("../CHANGELOG.md", 'r'): - + for line in open("../CHANGELOG.md", "r"): # ensure no tabs are present line = line.replace("\t", " ") @@ -187,12 +187,12 @@ def read_changelog(prior_rel, current_rel, format="md"): date = date2format(datetime.date(*[int(x) for x in date.split("-")])) except ValueError: pass - + return "".join(out).strip(), date def date2format(date): - """ Convert the date to the output format we require """ + """Convert the date to the output format we require""" date = date.strftime("%d of %B %Y") if date[0] == "0": date = date[1:] @@ -271,7 +271,9 @@ def main(token, revision_range, format="md"): from argparse import ArgumentParser parser = ArgumentParser(description="Generate author/pr lists for release") - parser.add_argument("--format", choices=("md", "rst"), help="which format to write out in") + parser.add_argument( + "--format", choices=("md", "rst"), help="which format to write out in" + ) parser.add_argument("token", help="github access token") parser.add_argument("revision_range", help="..") args = parser.parse_args() diff --git a/tools/codata.py b/tools/codata.py index e30983eaba..c5eddb3261 100644 --- a/tools/codata.py +++ b/tools/codata.py @@ -9,7 +9,7 @@ def parse_line(line): - """ Parse a data line returning + """Parse a data line returning name, value, (error), (unit) """ @@ -18,7 +18,7 @@ def parse_line(line): if len(lines) in (3, 4): lines[1] = float(lines[1].replace(" ", "").replace("...", "")) if "exact" in lines[2]: - lines[2] = 0. + lines[2] = 0.0 else: lines[2] = float(lines[2].replace(" ", "")) return lines @@ -45,7 +45,8 @@ class Constant: unit: str = "" def __str__(self): - return f"#: {self.doc} [{self.unit}]\n{self.name} = PhysicalConstant({self.value}, \"{self.unit}\")" + return f'#: {self.doc} [{self.unit}]\n{self.name} = PhysicalConstant({self.value}, "{self.unit}")' + CONSTANTS = { "speed of light in vacuum": Constant("Speed of light in vacuum", "c"), @@ -58,7 +59,6 @@ def __str__(self): } - def read_file(f): fh = open(f) start = False @@ -67,38 +67,38 @@ def read_file(f): unit_table = { "mass": { "DEFAULT": "amu", - "kg": 1., - "g": 1.e-3, + "kg": 1.0, + "g": 1.0e-3, }, "length": { "DEFAULT": "Ang", - "m": 1., + "m": 1.0, "cm": 0.01, - "nm": 1.e-9, - "Ang": 1.e-10, - "pm": 1.e-12, - "fm": 1.e-15, + "nm": 1.0e-9, + "Ang": 1.0e-10, + "pm": 1.0e-12, + "fm": 1.0e-15, }, "time": { "DEFAULT": "fs", - "s": 1., - "ns": 1.e-9, - "ps": 1.e-12, - "fs": 1.e-15, - "min": 60., - "hour": 3600., - "day": 86400., + "s": 1.0, + "ns": 1.0e-9, + "ps": 1.0e-12, + "fs": 1.0e-15, + "min": 60.0, + "hour": 3600.0, + "day": 86400.0, }, "energy": { "DEFAULT": "eV", - "J": 1., - "erg": 1.e-7, + "J": 1.0, + "erg": 1.0e-7, "K": 1.380648780669e-23, }, "force": { "DEFAULT": "eV/Ang", - "N": 1., - } + "N": 1.0, + }, } constants = [] @@ -107,7 +107,8 @@ def read_file(f): if "-----" in line: start = True continue - if not start: continue + if not start: + continue name, value, *error_unit = parse_line(line) @@ -116,7 +117,7 @@ def read_file(f): unit_table[entry][key] = value if key in ("Ry", "eV", "Ha"): - unit_table[entry][f"m{key}"] = value/1000 + unit_table[entry][f"m{key}"] = value / 1000 if name in CONSTANTS: c = CONSTANTS[name] @@ -128,16 +129,25 @@ def read_file(f): constants.append(c) if c.name == "h": - c = c.__class__(f"Reduced {c.doc}", "hbar", c.value / (2 * np.pi), c.unit) + c = c.__class__( + f"Reduced {c.doc}", "hbar", c.value / (2 * np.pi), c.unit + ) constants.append(c) # Clarify force - unit_table["force"][f"eV/Ang"] = unit_table["energy"]["eV"] / unit_table["length"]["Ang"] - unit_table["force"][f"Ry/Bohr"] = unit_table["energy"]["Ry"] / unit_table["length"]["Bohr"] - unit_table["force"][f"Ha/Bohr"] = unit_table["energy"]["Ha"] / unit_table["length"]["Bohr"] + unit_table["force"][f"eV/Ang"] = ( + unit_table["energy"]["eV"] / unit_table["length"]["Ang"] + ) + unit_table["force"][f"Ry/Bohr"] = ( + unit_table["energy"]["Ry"] / unit_table["length"]["Bohr"] + ) + unit_table["force"][f"Ha/Bohr"] = ( + unit_table["energy"]["Ha"] / unit_table["length"]["Bohr"] + ) return unit_table, constants + ut, cs = read_file(sys.argv[1]) from pprint import PrettyPrinter From 85c14bab954271c87f7c0186201824df545be0b1 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 10:29:17 +0100 Subject: [PATCH 03/10] added black to blame ignore as well as the GA Signed-off-by: Nick Papior --- .git-blame-ignore-revs | 3 +++ .github/workflows/black.yml | 13 +++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 .git-blame-ignore-revs create mode 100644 .github/workflows/black.yml diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000..b07f590e14 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,3 @@ +# Add here the commits that should be ignored +# when issuing blames +0ff17210003a85b08f683307e9a73555142c8707 diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000000..46d54f5502 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,13 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + jupyter: true + From 5b2f49964f9a6ac0c7dced390e6cc554035fcb9f Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 10:36:04 +0100 Subject: [PATCH 04/10] added black badge Signed-off-by: Nick Papior --- CHANGELOG.md | 1 + CONTRIBUTING.md | 2 ++ README.md | 1 + 3 files changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 003fe67383..d2c621e6fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ we hit release version 1.0.0. - fixed cases where `Geometry.close` would not catch all neighbours, #633 ### Changed +- sisl now enforces the black style - `Lattice` now holds the boundary conditions (not `Grid`), see #626 - Some siles exposed certain properties containing basic information about the content, say number of atoms/orbitals etc. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 12738a788d..46b1e95c95 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,6 +15,8 @@ you should add this change to your `.git/config`, or in your global `.gitconfig` [filter "strip-notebook-output"] clean = "jupyter nbconvert --ClearOutputPreprocessor.enabled=True --to=notebook --stdin --stdout --log-level=ERROR" +We also enforce the black style, please run black before committing. + ## First-time contributors Add a comment on the issue and wait for the issue to be assigned before you start working on it. This helps to avoid multiple people working on similar issues. diff --git a/README.md b/README.md index eef4588360..d8ae73ae1b 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![Install sisl using PyPI](https://badge.fury.io/py/sisl.svg)](https://pypi.org/project/sisl) [![Install sisl using conda](https://anaconda.org/conda-forge/sisl/badges/version.svg)](https://anaconda.org/conda-forge/sisl) [![License: MPL 2.0](https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg)](https://www.mozilla.org/en-US/MPL/2.0/) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![DOI for citation](https://zenodo.org/badge/doi/10.5281/zenodo.597181.svg)](https://doi.org/10.5281/zenodo.597181) [![Join discussion on Discord](https://img.shields.io/discord/742636379871379577.svg?label=&logo=discord&logoColor=ffffff&color=green&labelColor=red)](https://discord.gg/5XnFXFdkv2) From 04d8b46c4513310a89101256853f6d4fdef92560 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 10:40:45 +0100 Subject: [PATCH 05/10] fixed tabs in workflow Signed-off-by: Nick Papior --- .github/workflows/black.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 46d54f5502..bfc6712473 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -8,6 +8,5 @@ jobs: steps: - uses: actions/checkout@v4 - uses: psf/black@stable - with: - jupyter: true - + with: + jupyter: true From 12f6f94554d629da6383ae46405869771ad40ce1 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 12:23:57 +0100 Subject: [PATCH 06/10] added isort to linter Signed-off-by: Nick Papior --- .github/workflows/{black.yml => linter.yml} | 1 + 1 file changed, 1 insertion(+) rename .github/workflows/{black.yml => linter.yml} (82%) diff --git a/.github/workflows/black.yml b/.github/workflows/linter.yml similarity index 82% rename from .github/workflows/black.yml rename to .github/workflows/linter.yml index bfc6712473..e077a30c56 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/linter.yml @@ -10,3 +10,4 @@ jobs: - uses: psf/black@stable with: jupyter: true + - uses: isort/isort-action@master From 80edad570a95839908eff6e2acee8ab36000cc85 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 12:36:28 +0100 Subject: [PATCH 07/10] added black to pr template Signed-off-by: Nick Papior --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 227ef8c5bf..ea92251c0b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,6 +2,6 @@ - [ ] Closes #xxxx - [ ] Tests added - - [ ] Ranned `isort .` at top--level + - [ ] Ranned `isort .` and `black .` at top--level - [ ] Documentation for functionality, `docs/` - [ ] Changes documented, `CHANGELOG.md` From 26c2d32f95b2974c22cf3dbf21d22d988127dbe8 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 13:10:18 +0100 Subject: [PATCH 08/10] trying to chain GA Signed-off-by: Nick Papior --- .github/PULL_REQUEST_TEMPLATE.md | 10 +++++----- .github/workflows/linter.yml | 10 +++++++++- .github/workflows/test.yaml | 17 ++++++++++------- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index ea92251c0b..a94ce42fb5 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,7 +1,7 @@ - - [ ] Closes #xxxx - - [ ] Tests added - - [ ] Ranned `isort .` and `black .` at top--level - - [ ] Documentation for functionality, `docs/` - - [ ] Changes documented, `CHANGELOG.md` + - [ ] Closes #x + - [ ] Added tests for new/changed functions? + - [ ] Ran `isort .` and `black .` at top-level + - [ ] Documentation for functionality in `docs/` + - [ ] Changes documented in `CHANGELOG.md` diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index e077a30c56..a9afc49aa1 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -1,6 +1,14 @@ name: Lint -on: [push, pull_request] +on: + push: + paths: + - '**.py' + - '**.ipynb' + pull_request: + paths: + - '**.py' + - '**.ipynb' jobs: lint: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2d5404c2e8..43bf214abf 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -10,9 +10,9 @@ env: on: - pull_request: - # all pull-requests on to main - branches: [main] + workflow_run: + workflows: [Lint] + types: [completed] schedule: # only once every 4 days # We can always force run this. @@ -32,8 +32,11 @@ on: jobs: # Define a few jobs that can be runned - check_schedule: - if: ${{ github.event_name == 'schedule' }} && ${{ github.actor != 'dependabot[bot]' }} + check_if_runnable: + if: | + github.event_name == 'schedule' + && github.actor != 'dependabot[bot]' + && github.event.workflow_run.conclusion == 'success' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -42,8 +45,8 @@ jobs: - run: test -n $(git rev-list --after="1 week" --max-count=1 ${{ github.sha }}) test_runs: - needs: [check_schedule] - if: ${{ always() && (contains(needs.*.result, 'success') || contains(needs.*.result, 'skipped')) }} + needs: [check_if_runnable] + if: ${{ contains(needs.*.result, 'success') || contains(needs.*.result, 'skipped') }} runs-on: ${{ matrix.os }} strategy: matrix: From ccd5f7ece0202bd0a504eff2db16402886b90407 Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 13:12:10 +0100 Subject: [PATCH 09/10] fixed tabs Signed-off-by: Nick Papior --- .github/workflows/test.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 43bf214abf..45e6a63ebe 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -34,9 +34,9 @@ jobs: # Define a few jobs that can be runned check_if_runnable: if: | - github.event_name == 'schedule' - && github.actor != 'dependabot[bot]' - && github.event.workflow_run.conclusion == 'success' + github.event_name == 'schedule' + && github.actor != 'dependabot[bot]' + && github.event.workflow_run.conclusion == 'success' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 From d79400dbe199e724381de1ed2aec0b534e0a506d Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Fri, 3 Nov 2023 13:17:42 +0100 Subject: [PATCH 10/10] fixed naming in bug template title Signed-off-by: Nick Papior --- .github/ISSUE_TEMPLATE/005_bug.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/005_bug.md b/.github/ISSUE_TEMPLATE/005_bug.md index 71d71dc450..394a770e9b 100644 --- a/.github/ISSUE_TEMPLATE/005_bug.md +++ b/.github/ISSUE_TEMPLATE/005_bug.md @@ -6,7 +6,7 @@ about: Let us know if something went wrong **Describe the bug** -**Reproducable code** +**Code to reproduce problem** ```python ```