{ "cells": [ { "cell_type": "markdown", "id": "2b1e109f-b2f5-42a8-bcc3-0acf2d7af9a0", "metadata": {}, "source": [ "# modelforge.curate : Record and SourceDataset\n", "\n", "This notebook focuses on functionality within the `Record` and `SourceDataset` classes." ] }, { "cell_type": "code", "execution_count": 1, "id": "6cf4ef4a-d9b9-49bc-bd8a-3f5e62953c5c", "metadata": {}, "outputs": [], "source": [ "from modelforge.curate import Record, SourceDataset\n", "from modelforge.utils.units import GlobalUnitSystem\n", "from modelforge.curate import AtomicNumbers, Positions, Energies, Forces, MetaData\n", "\n", "from openff.units import unit\n", "\n", "import numpy as np" ] }, { "cell_type": "markdown", "id": "bcfb4ef7-18f9-41dd-9b08-14bb8afd88d7", "metadata": {}, "source": [ "## Initializing records and datasets\n", "To start, we will create a new instance of the `SourceDataset` class to store the dataset. We will populate this with 10 records, each with 3 configurations." ] }, { "cell_type": "code", "execution_count": 2, "id": "bad426d0-eb06-44b3-8ec5-7bb7cb016e0a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-08-25 12:06:03.587\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mmodelforge.curate.sourcedataset\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m66\u001b[0m - \u001b[33m\u001b[1mDatabase file test_dataset.sqlite already exists in ./. 
Removing it.\u001b[0m\n" ] } ], "source": [ "new_dataset = SourceDataset(name=\"test_dataset\")\n", "\n", "for i in range(0,10):\n", " record = Record(f\"mol_{i}\")\n", " \n", " atomic_numbers = AtomicNumbers(value=np.array([[1], [6]]))\n", "\n", " positions = Positions(\n", " value=np.array([[[i, 1.0, 1.0], [2.0, 2.0, 2.0]],\n", " [[i, 2.0, 1.0], [2.0, 2.0, 2.0]],\n", " [[i, 3.0, 1.0], [2.0, 2.0, 2.0]]]), \n", " units=\"nanometer\"\n", " )\n", " \n", " total_energies = Energies(\n", " name=\"total_energies\",\n", " value=np.array([[i], \n", " [i+0.1], \n", " [i+0.2]]), \n", " units=unit.hartree\n", " )\n", " forces = Forces(\n", " name=\"forces\",\n", " value=np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],\n", " [[10.0, 2.0, 1.0], [2.0, 2.0, 2.0]],\n", " [[20.0, 3.0, 1.0], [2.0, 2.0, 2.0]]]), \n", " units = unit.kilocalorie_per_mole/unit.nanometer,\n", " )\n", " record.add_properties([atomic_numbers, positions, total_energies, forces])\n", " new_dataset.add_record(record)" ] }, { "cell_type": "markdown", "id": "e09ca833-4b76-4b8e-8ed8-632069fa0470", "metadata": {}, "source": [ "### Examining the dataset\n", "Let us examine the dataset:" ] }, { "cell_type": "code", "execution_count": 3, "id": "ce4af77b-d32f-41a9-a85f-54c557505121", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total configurations: 30\n", "total records: 10\n", "dataset summary:\n", "{'name': 'test_dataset',\n", " 'properties': {'atomic_numbers': {'classification': 'atomic_numbers'},\n", " 'forces': {'classification': 'per_atom',\n", " 'units': 'kilojoule_per_mole / nanometer'},\n", " 'positions': {'classification': 'per_atom',\n", " 'units': 'nanometer'},\n", " 'total_energies': {'classification': 'per_system',\n", " 'units': 'kilojoule_per_mole'}},\n", " 'total_configurations': 30,\n", " 'total_records': 10}\n" ] } ], "source": [ "print(\"total configurations: \", new_dataset.total_configs())\n", "print(\"total records: \", new_dataset.total_records())\n", 
"\n", "import pprint\n", "print(\"dataset summary:\")\n", "pprint.pprint(new_dataset.generate_dataset_summary())\n" ] }, { "cell_type": "markdown", "id": "29daef5b-7fdf-44cb-aed0-40399490cad2", "metadata": {}, "source": [ "### Extracting/Updating records\n", "\n", "#### Print a record\n", "We can can print out the summary of any invidual record using the `print_record` function. " ] }, { "cell_type": "code", "execution_count": 4, "id": "3fb55bf7-184e-413c-9a1b-fcdf58a850d5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name: mol_0\n", "* n_atoms: 2\n", "* n_configs: 3\n", "* atomic_numbers:\n", " - name='atomic_numbers' value=array([[1],\n", " [6]]) units= classification='atomic_numbers' property_type='atomic_numbers' n_configs=None n_atoms=2\n", "* per-atom properties: (['positions', 'forces']):\n", " - name='positions' value=array([[[0., 1., 1.],\n", " [2., 2., 2.]],\n", "\n", " [[0., 2., 1.],\n", " [2., 2., 2.]],\n", "\n", " [[0., 3., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='length' n_configs=3 n_atoms=2\n", " - name='forces' value=array([[[ 1., 1., 1.],\n", " [ 2., 2., 2.]],\n", "\n", " [[10., 2., 1.],\n", " [ 2., 2., 2.]],\n", "\n", " [[20., 3., 1.],\n", " [ 2., 2., 2.]]]) units= classification='per_atom' property_type='force' n_configs=3 n_atoms=2\n", "* per-system properties: (['total_energies']):\n", " - name='total_energies' value=array([[0. ],\n", " [0.1],\n", " [0.2]]) units= classification='per_system' property_type='energy' n_configs=3 n_atoms=None\n", "* meta_data: ([])\n", "\n" ] } ], "source": [ "new_dataset.print_record(\"mol_0\")" ] }, { "cell_type": "markdown", "id": "82e0d970-615e-4c72-a4d2-07e85a1a48f6", "metadata": {}, "source": [ "#### Extract a copy of a record\n", "We can extract a copy of any record using the `get_record` function. 
" ] }, { "cell_type": "code", "execution_count": 5, "id": "977541b3-8f71-435d-8bc0-aadbd141c7a1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name: mol_0\n", "* n_atoms: 2\n", "* n_configs: 3\n", "* atomic_numbers:\n", " - name='atomic_numbers' value=array([[1],\n", " [6]]) units= classification='atomic_numbers' property_type='atomic_numbers' n_configs=None n_atoms=2\n", "* per-atom properties: (['positions', 'forces']):\n", " - name='positions' value=array([[[0., 1., 1.],\n", " [2., 2., 2.]],\n", "\n", " [[0., 2., 1.],\n", " [2., 2., 2.]],\n", "\n", " [[0., 3., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='length' n_configs=3 n_atoms=2\n", " - name='forces' value=array([[[ 1., 1., 1.],\n", " [ 2., 2., 2.]],\n", "\n", " [[10., 2., 1.],\n", " [ 2., 2., 2.]],\n", "\n", " [[20., 3., 1.],\n", " [ 2., 2., 2.]]]) units= classification='per_atom' property_type='force' n_configs=3 n_atoms=2\n", "* per-system properties: (['total_energies']):\n", " - name='total_energies' value=array([[0. ],\n", " [0.1],\n", " [0.2]]) units= classification='per_system' property_type='energy' n_configs=3 n_atoms=None\n", "* meta_data: ([])\n", "\n" ] } ], "source": [ "record_temp = new_dataset.get_record(\"mol_0\")\n", "print(record_temp)" ] }, { "cell_type": "markdown", "id": "a857f7df-5615-4b2a-93af-b6ce56ca7bb4", "metadata": {}, "source": [ "#### Update a record in the dataset\n", "Since `get_record` returns a copy, if the record is changed, the `update_record` function needs to be used to updated it within the dataset. Here we can add metadata to this record and update it. 
" ] }, { "cell_type": "code", "execution_count": 6, "id": "e846f51c-57bb-44e1-9fa2-6ec16bd8a239", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name: mol_0\n", "* n_atoms: 2\n", "* n_configs: 3\n", "* atomic_numbers:\n", " - name='atomic_numbers' value=array([[1],\n", " [6]]) units= classification='atomic_numbers' property_type='atomic_numbers' n_configs=None n_atoms=2\n", "* per-atom properties: (['positions', 'forces']):\n", " - name='positions' value=array([[[0., 1., 1.],\n", " [2., 2., 2.]],\n", "\n", " [[0., 2., 1.],\n", " [2., 2., 2.]],\n", "\n", " [[0., 3., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='length' n_configs=3 n_atoms=2\n", " - name='forces' value=array([[[ 1., 1., 1.],\n", " [ 2., 2., 2.]],\n", "\n", " [[10., 2., 1.],\n", " [ 2., 2., 2.]],\n", "\n", " [[20., 3., 1.],\n", " [ 2., 2., 2.]]]) units= classification='per_atom' property_type='force' n_configs=3 n_atoms=2\n", "* per-system properties: (['total_energies']):\n", " - name='total_energies' value=array([[0. ],\n", " [0.1],\n", " [0.2]]) units= classification='per_system' property_type='energy' n_configs=3 n_atoms=None\n", "* meta_data: (['smiles'])\n", " - name='smiles' value='[CH]' units= classification='meta_data' property_type='meta_data' n_configs=None n_atoms=None\n", "\n" ] } ], "source": [ "smiles = MetaData(name='smiles', value='[CH]')\n", "\n", "record_temp.add_property(smiles)\n", "\n", "new_dataset.update_record(record_temp)\n", "\n", "new_dataset.print_record(\"mol_0\")" ] }, { "cell_type": "markdown", "id": "606ca8f2-ecae-48e2-af72-0f513d744a71", "metadata": {}, "source": [ "#### Removing a record from a dataset\n", "\n", "We can remove a record using the `remove_record` function in the `SourceDataset` class. 
" ] }, { "cell_type": "code", "execution_count": 7, "id": "432734bd-e3f5-4685-91ed-dfa62e0f0604", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total_records: 10\n", "total_records: 9\n" ] } ], "source": [ "print(\"total_records: \", new_dataset.total_records())\n", "new_dataset.remove_record(\"mol_9\")\n", "print(\"total_records: \", new_dataset.total_records())" ] }, { "cell_type": "markdown", "id": "54228e48-fe63-4f53-97dd-6508b907f8d6", "metadata": {}, "source": [ "#### Slicing a record\n", "\n", "We can slice a record, returning a copy of the record that only includes subset of configurations. This will be applied to all properties with the record. \n", "\n", "This can be done at the level of a record or called via a wrapping function in the dataset. \n", "\n", "the code below will return the first 2 records out of the 3 total. " ] }, { "cell_type": "code", "execution_count": 8, "id": "03c4e8b9-5549-47a8-ae8c-c723598189d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name: mol_0\n", "* n_atoms: 2\n", "* n_configs: 1\n", "* atomic_numbers:\n", " - name='atomic_numbers' value=array([[1],\n", " [6]]) units= classification='atomic_numbers' property_type='atomic_numbers' n_configs=None n_atoms=2\n", "* per-atom properties: (['positions', 'forces']):\n", " - name='positions' value=array([[[0., 1., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='length' n_configs=1 n_atoms=2\n", " - name='forces' value=array([[[1., 1., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='force' n_configs=1 n_atoms=2\n", "* per-system properties: (['total_energies']):\n", " - name='total_energies' value=array([[0.]]) units= classification='per_system' property_type='energy' n_configs=1 n_atoms=None\n", "* meta_data: (['smiles'])\n", " - name='smiles' value='[CH]' units= classification='meta_data' property_type='meta_data' n_configs=None n_atoms=None\n", "\n" 
] } ], "source": [ "record_sliced = record_temp.slice_record(min=0, max=1)\n", "\n", "print(record_sliced)" ] }, { "cell_type": "code", "execution_count": 9, "id": "bb9da9a4-361b-4cda-a452-b256508ac249", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name: mol_0\n", "* n_atoms: 2\n", "* n_configs: 1\n", "* atomic_numbers:\n", " - name='atomic_numbers' value=array([[1],\n", " [6]]) units= classification='atomic_numbers' property_type='atomic_numbers' n_configs=None n_atoms=2\n", "* per-atom properties: (['positions', 'forces']):\n", " - name='positions' value=array([[[0., 1., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='length' n_configs=1 n_atoms=2\n", " - name='forces' value=array([[[1., 1., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='force' n_configs=1 n_atoms=2\n", "* per-system properties: (['total_energies']):\n", " - name='total_energies' value=array([[0.]]) units= classification='per_system' property_type='energy' n_configs=1 n_atoms=None\n", "* meta_data: (['smiles'])\n", " - name='smiles' value='[CH]' units= classification='meta_data' property_type='meta_data' n_configs=None n_atoms=None\n", "\n" ] } ], "source": [ "record_sliced = new_dataset.slice_record(\"mol_0\", min=0, max=1)\n", "print(record_sliced)" ] }, { "cell_type": "markdown", "id": "00063d6e-821f-46f1-a89c-1e32ea4fa174", "metadata": {}, "source": [ "#### Limiting to a subset of atomic numbers\n", "\n", "We can check whether a record contains only atomic numbers from a specified set using the `contains_atomic_numbers` function in the `Record` class. It returns `True` if every atomic number in the record appears in the provided array, and `False` if any atomic number in the record is missing from the provided array. 
\n", "\n", "Note, this function will not typically need to be called directly, as the `subset_dataset` function in the `SourceDataset` provides a wrapper for this functionality on the entire dataset (discussed separately later). " ] }, { "cell_type": "code", "execution_count": 10, "id": "f42dac4d-d7ff-4b78-baf4-49ed178d4373", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "record_temp.contains_atomic_numbers(np.array([1,6]))" ] }, { "cell_type": "code", "execution_count": 11, "id": "b0600fb3-c8fb-4ffd-a57b-bc2dcf3e29c4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "record_temp.contains_atomic_numbers(np.array([1,8]))" ] }, { "cell_type": "markdown", "id": "1a227658-65c3-4757-b7fe-d2395169bc3e", "metadata": {}, "source": [ "#### Removing high force configurations\n", "\n", "Often, we wish to remove configurations with very high forces. The `remove_high_force_configs` function in the `Record` class returns a copy of the record that excludes those configurations where the magnitude of the force exceeds the specified threshold. By default, this will filter using the name \"forces\" (i.e., it will look for a property with name \"forces\" within the record); a different name can be passed if the force property is named differently. \n", "\n", "Note, this function will not typically need to be called directly, as the `subset_dataset` function in the `SourceDataset` provides a wrapper for this functionality on the entire dataset (discussed separately later). \n", "\n", "For example, below we filter out any configurations with a force greater than 15 kcal/mol/nm, which eliminates the last configuration of the record (see initialization above). 
" ] }, { "cell_type": "code", "execution_count": 12, "id": "7570d75b-b5ac-4862-82e0-d8c37a01ef99", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name: mol_0\n", "* n_atoms: 2\n", "* n_configs: 2\n", "* atomic_numbers:\n", " - name='atomic_numbers' value=array([[1],\n", " [6]]) units= classification='atomic_numbers' property_type='atomic_numbers' n_configs=None n_atoms=2\n", "* per-atom properties: (['positions', 'forces']):\n", " - name='positions' value=array([[[0., 1., 1.],\n", " [2., 2., 2.]],\n", "\n", " [[0., 2., 1.],\n", " [2., 2., 2.]]]) units= classification='per_atom' property_type='length' n_configs=2 n_atoms=2\n", " - name='forces' value=array([[[ 1., 1., 1.],\n", " [ 2., 2., 2.]],\n", "\n", " [[10., 2., 1.],\n", " [ 2., 2., 2.]]]) units= classification='per_atom' property_type='force' n_configs=2 n_atoms=2\n", "* per-system properties: (['total_energies']):\n", " - name='total_energies' value=array([[0. ],\n", " [0.1]]) units= classification='per_system' property_type='energy' n_configs=2 n_atoms=None\n", "* meta_data: (['smiles'])\n", " - name='smiles' value='[CH]' units= classification='meta_data' property_type='meta_data' n_configs=None n_atoms=None\n", "\n" ] } ], "source": [ "record_max_force = record_temp.remove_high_force_configs(unit.Quantity(15, unit.kilocalorie_per_mole/unit.nanometer), \"forces\")\n", "\n", "print(record_max_force)" ] }, { "cell_type": "markdown", "id": "aae624b3-f5f1-4543-990d-be2a79697b04", "metadata": {}, "source": [ "### Subsetting a dataset\n", "\n", "`SourceDataset` includes a function called `subset_dataset` which returns a copy of the dataset with various filters applied. The filters that can be applied include:\n", "\n", "- total_records: Maximum number of records to include in the subset.\n", "- total_configurations: Total number of conformers to include in the subset.\n", "- max_configurations_per_record: Maximum number of conformers to include per record. 
If None, all conformers in a record will be included.\n", "- atomic_numbers_to_limit: An array of atomic species to limit the dataset to. Any molecules that contain elements outside of this list will be ignored.\n", "- max_force: If set, configurations with forces greater than this value will be removed.\n", "- final_configuration_only: If True, only the final configuration of each record will be included in the subset.\n", "- max_configurations_per_record_order: Can be \"start\", \"end\", or \"random\", which defines whether configurations are taken from the start of the underlying array, the end of the array, or randomly chosen, respectively. Note, users can also pass the \"seed\" used to initialize the random number generator for the random selection, to ensure reproducibility and/or generate unique subsets.\n", "\n", "Note, `total_records` and `total_configurations` cannot be used together. \n", "\n", "Below, we create a new dataset limited to a maximum of 2 configurations per record and 10 configurations in total. " ] }, { "cell_type": "code", "execution_count": 13, "id": "95f7282e-f04c-4086-9756-3f74d6b0b081", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-08-25 12:09:41.588\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mmodelforge.curate.sourcedataset\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m66\u001b[0m - \u001b[33m\u001b[1mDatabase file dataset_subset.sqlite already exists in ./. 
Removing it.\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "5\n", "10\n" ] } ], "source": [ "dataset_subset = new_dataset.subset_dataset(new_dataset_name=\"dataset_subset\", total_configurations=10, max_configurations_per_record=2)\n", "\n", "print(dataset_subset.total_records())\n", "print(dataset_subset.total_configs())" ] }, { "cell_type": "markdown", "id": "664d3e02-4c98-452f-bd74-e62d2587a4ee", "metadata": {}, "source": [ "## SourceDataset backend sqlite database\n", "\n", "The `SourceDataset` class stores records within a sqlite database rather than in memory. The name and location of this database can be set at instantiation of the dataset. If these are not set, the default location will be \"./\" and the database will be named after the dataset (replacing any spaces with an underscore). The code below sets both values explicitly, using the same directory and file name that the defaults would produce. " ] }, { "cell_type": "code", "execution_count": 14, "id": "ad4cc5e1-8971-4117-89aa-8bb6a6768a2f", "metadata": {}, "outputs": [], "source": [ "new_dataset2 = SourceDataset(name=\"new dataset2\", local_db_dir=\"./\", local_db_name=\"new_dataset2.sqlite\")" ] }, { "cell_type": "markdown", "id": "b887d70b-5825-480c-adb4-32cdf60db92b", "metadata": {}, "source": [ "The use of a sqlite backend not only reduces the memory footprint, but also allows a dataset to be loaded from an existing database. Loading from the database lets us skip reprocessing a dataset (i.e., setting up individual properties, Records, etc.). 
\n", "\n", "The following code will load up the subsetted dataset generated in the prior cells:" ] }, { "cell_type": "code", "execution_count": 15, "id": "563ab11a-20d5-4a30-b079-64ecd8318e8f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5\n", "10\n" ] } ], "source": [ "new_dataset2 = SourceDataset(name=\"new dataset2\", local_db_dir=\"./\", local_db_name=\"dataset_subset.sqlite\", read_from_local_db=True)\n", "\n", "print(new_dataset2.total_records())\n", "print(new_dataset2.total_configs())" ] }, { "cell_type": "markdown", "id": "e4783fbb-13ef-476a-8f34-5e9b128b3155", "metadata": {}, "source": [ "When subsetting a dataset, we can also specify the name and location of the database that will be generated. Otherwise, the same default behavior is used (i.e., based on dataset name). This function will return an error if the new and old datasets have the same name. " ] }, { "cell_type": "code", "execution_count": null, "id": "6342ce9e-00d5-40a1-8946-c46f275263b9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 5 }