Skip to content

Commit fedcd90

Browse files
authored
Merge pull request #908 from PyThaiNLP/add-save-load-param_free
Add save and load for pythainlp.classify.param_free.GzipModel
2 parents 6e7d917 + e7a1c82 commit fedcd90

File tree

3 files changed

+82
-6
lines changed

3 files changed

+82
-6
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,5 @@ cython_debug/
117117
notebooks/iso_11940-dev.ipynb
118118

119119
# vscode devcontainer
120-
.devcontainer/
120+
.devcontainer/
121+
notebooks/d.model

notebooks/test_gzip_classify.ipynb

+56-2
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,60 @@
6060
"source": [
6161
"model.predict(\"ฉันดีใจ\", k=1)"
6262
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": 5,
67+
"id": "5a97f0d3",
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"model.save(\"d.model\")"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": 6,
77+
"id": "6e183243",
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"model2 = pythainlp.classify.param_free.GzipModel(model_path=\"d.model\")"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": 7,
87+
"id": "b30af6f0",
88+
"metadata": {},
89+
"outputs": [
90+
{
91+
"data": {
92+
"text/plain": [
93+
"'Positive'"
94+
]
95+
},
96+
"execution_count": 7,
97+
"metadata": {},
98+
"output_type": "execute_result"
99+
}
100+
],
101+
"source": [
102+
"model2.predict(x1=\"ฉันดีใจ\", k=1)"
103+
]
104+
},
105+
{
106+
"cell_type": "code",
107+
"execution_count": null,
108+
"id": "3e72a33b",
109+
"metadata": {},
110+
"outputs": [],
111+
"source": []
63112
}
64113
],
65114
"metadata": {
66115
"kernelspec": {
67-
"display_name": "Python 3 (ipykernel)",
116+
"display_name": "Python 3.8.13 ('base')",
68117
"language": "python",
69118
"name": "python3"
70119
},
@@ -78,7 +127,12 @@
78127
"name": "python",
79128
"nbconvert_exporter": "python",
80129
"pygments_lexer": "ipython3",
81-
"version": "3.10.9"
130+
"version": "3.8.13"
131+
},
132+
"vscode": {
133+
"interpreter": {
134+
"hash": "a1d6ff38954a1cdba4cf61ffa51e42f4658fc35985cd256cd89123cae8466a39"
135+
}
82136
}
83137
},
84138
"nbformat": 4,

pythainlp/classify/param_free.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import gzip
66
from typing import List, Tuple
77
import numpy as np
8+
import json
89

910

1011
class GzipModel:
@@ -14,11 +15,15 @@ class GzipModel:
1415
(Jiang et al., Findings 2023)
1516
1617
:param list training_data: list [(text_sample,label)]
18+
:param str model_path: Path for loading model (if you saved the model)
1719
"""
1820

19-
def __init__(self, training_data: List[Tuple[str, str]]):
20-
self.training_data = np.array(training_data)
21-
self.Cx2_list = self.train()
21+
def __init__(self, training_data: List[Tuple[str, str]] = None, model_path: str = None):
    """
    Build the model either from a saved model file or from training data.

    :param list training_data: list [(text_sample,label)]; used only
        when ``model_path`` is not given
    :param str model_path: path of a model file written by ``save``;
        when given, it takes precedence over ``training_data``
    :raises ValueError: if neither ``training_data`` nor ``model_path``
        is provided
    """
    if model_path is not None:
        # A saved model already holds both the training data and the
        # values computed by train(), so nothing is recomputed here.
        self.load(model_path)
        return
    if training_data is None:
        # Fail early with a clear message instead of wrapping None in an
        # array and passing it on to train().
        raise ValueError(
            "Either training_data or model_path must be provided."
        )
    self.training_data = np.array(training_data)
    self.Cx2_list = self.train()
2227

2328
def train(self):
2429
Cx2_list = []
@@ -72,3 +77,19 @@ def predict(self, x1: str, k: int = 1) -> str:
7277
predict_class = top_k_class[counts.argmax()]
7378

7479
return predict_class
80+
81+
def save(self, path: str):
    """
    Save the model to a JSON file.

    The file stores two keys: ``training_data`` (``self.training_data``
    converted to nested lists) and ``Cx2_list`` (``self.Cx2_list``).

    :param str path: path of the file to write the model to
    """
    payload = {
        "training_data": self.training_data.tolist(),
        "Cx2_list": self.Cx2_list,
    }
    # ensure_ascii=False writes non-ASCII text (e.g. Thai) verbatim, so the
    # file encoding is pinned to UTF-8 rather than left to the platform
    # default, which may be unable to encode it.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False)
90+
91+
def load(self, path: str):
    """
    Load a model from a JSON file and set it on this instance.

    Reads the ``Cx2_list`` and ``training_data`` keys from the file and
    assigns them to ``self.Cx2_list`` and ``self.training_data`` (the
    latter converted to a numpy array).

    :param str path: path of the model file to read
    """
    # Read as UTF-8 explicitly so that non-ASCII text in the file decodes
    # the same way regardless of the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    self.Cx2_list = data["Cx2_list"]
    self.training_data = np.array(data["training_data"])

0 commit comments

Comments
 (0)