假设你想修改 print_report() 函数,以支持各种不同的输出格式,例如纯文本,HTML, CSV,或者 XML。为此,你可以尝试编写一个庞大的函数来实现每一个功能。但是,这样做可能会导致代码非常混乱,无法维护。这是一个使用继承的绝佳机会。
首先,请关注创建表所涉及的步骤。在表的顶部是标题。标题的后面是数据行。让我们使用这些步骤把它们放到各自的类中吧。创建一个名为 tableformat.py 的文件,并定义以下类:
# tableformat.py class TableFormatter: def headings(self, headers): ''' Emit the table headings. ''' raise NotImplementedError() def row(self, rowdata): ''' Emit a single row of table data. ''' raise NotImplementedError()除了稍后用作定义其它类的设计规范,该类什么也不做。有时候,这样的类被称为“抽象基类”。
请修改 print_report() 函数,使其接受一个 TableFormatter 对象作为输入,并执行 TableFormatter 的方法来生成输出。示例:
# report.py ... def print_report(reportdata, formatter): ''' Print a nicely formated table from a list of (name, shares, price, change) tuples. ''' formatter.headings(['Name','Shares','Price','Change']) for name, shares, price, change in reportdata: rowdata = [ name, str(shares), f'{price:0.2f}', f'{change:0.2f}' ] formatter.row(rowdata)因为你在 portfolio_report() 函数中增加了一个参数,所以你也需要修改 portfolio_report() 函数。请修改 portfolio_report() 函数,以便像下面这样创建 TableFormatter:
# report.py import tableformat ... def portfolio_report(portfoliofile, pricefile): ''' Make a stock report given portfolio and price data files. ''' # Read data files portfolio = read_portfolio(portfoliofile) prices = read_prices(pricefile) # Create the report data report = make_report_data(portfolio, prices) # Print it out formatter = tableformat.TableFormatter() print_report(report, formatter)运行新代码:
>>> ================================ RESTART ================================ >>> import report >>> report.portfolio_report('Data/portfolio.csv', 'Data/prices.csv') ... crashes ...程序应该会马上崩溃,并附带一个 NotImplementedError 异常。虽然这没有那么令人兴奋,但是结果确实是我们期待的。继续下一步部分。
练习 4.6:使用继承生成不同的输出在 a 部分定义的 TableFormatter 类旨在通过继承进行扩展。实际上,这就是整个思想。要说明这点,请像下面这样定义 TextTableFormatter 类:
# tableformat.py ... class TextTableFormatter(TableFormatter): ''' Emit a table in plain-text format ''' def headings(self, headers): for h in headers: print(f'{h:>10s}', end=' ') print() print(('-'*10 + ' ')*len(headers)) def row(self, rowdata): for d in rowdata: print(f'{d:>10s}', end=' ') print()请像下面这样修改 portfolio_report() 函数:
# report.py ... def portfolio_report(portfoliofile, pricefile): ''' Make a stock report given portfolio and price data files. ''' # Read data files portfolio = read_portfolio(portfoliofile) prices = read_prices(pricefile) # Create the report data report = make_report_data(portfolio, prices) # Print it out formatter = tableformat.TextTableFormatter() print_report(report, formatter)这应该会生成和之前一样的输出:
>>> ================================ RESTART ================================ >>> import report >>> report.portfolio_report('Data/portfolio.csv', 'Data/prices.csv') Name Shares Price Change ---------- ---------- ---------- ---------- AA 100 9.22 -22.98 IBM 50 106.28 15.18 CAT 150 35.46 -47.98 MSFT 200 20.89 -30.34 GE 95 13.48 -26.89 MSFT 50 20.89 -44.21 IBM 100 106.28 35.84 >>>但是,让我们更改输出为其它内容。定义一个以 CSV 格式生成输出的 CSVTableFormatter。
# tableformat.py ... class CSVTableFormatter(TableFormatter): ''' Output portfolio data in CSV format. ''' def headings(self, headers): print(','.join(headers)) def row(self, rowdata): print(','.join(rowdata))请像下面这样修改主程序:
def portfolio_report(portfoliofile, pricefile): ''' Make a stock report given portfolio and price data files. ''' # Read data files portfolio = read_portfolio(portfoliofile) prices = read_prices(pricefile) # Create the report data report = make_report_data(portfolio, prices) # Print it out formatter = tableformat.CSVTableFormatter() print_report(report, formatter)然后,你应该会看到像下面这样的 CSV 输出:
>>> ================================ RESTART ================================ >>> import report >>> report.portfolio_report('Data/portfolio.csv', 'Data/prices.csv') Name,Shares,Price,Change AA,100,9.22,-22.98 IBM,50,106.28,15.18 CAT,150,35.46,-47.98 MSFT,200,20.89,-30.34 GE,95,13.48,-26.89 MSFT,50,20.89,-44.21 IBM,100,106.28,35.84